Merge remote-tracking branch 'upstream/master' into xtensa-fusion-f1

This commit is contained in:
Advait Jain 2020-12-08 21:45:17 -08:00
commit 6b24408403
281 changed files with 6002 additions and 4504 deletions

View File

@ -61,6 +61,7 @@
* Added support for saved model's session initializer through
  `TFLiteConverter.from_saved_model`.
* Added dynamic range quantization support for the BatchMatMul op.
* Added DEPTH_TO_SPACE support in post-training quantization.
* Add `RFFT2D` as builtin op. (`RFFT2D` also supports `RFFTD`.) Currently
  only supports float32 input.
* TFLite supports SignatureDef:

View File

@ -145,6 +145,8 @@ if _running_from_pip_package():
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
if _os.path.exists(_plugin_dir):
  _ll.load_library(_plugin_dir)
  # Load Pluggable Device Library
  _ll.load_pluggable_device_library(_plugin_dir)
# Add module aliases
if hasattr(_current_module, 'keras'):

View File

@ -155,6 +155,8 @@ if _running_from_pip_package():
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
if _os.path.exists(_plugin_dir):
  _ll.load_library(_plugin_dir)
  # Load Pluggable Device Library
  _ll.load_pluggable_device_library(_plugin_dir)
# Delete modules that should be hidden from dir().
# Don't fail if these modules are not available.

View File

@ -684,7 +684,10 @@ tf_cc_test(
name = "c_api_experimental_test", name = "c_api_experimental_test",
size = "medium", size = "medium",
srcs = ["c_api_experimental_test.cc"], srcs = ["c_api_experimental_test.cc"],
data = ["testdata/tf_record"], data = [
"testdata/tf_record",
"//tensorflow/c/experimental/stream_executor/test:test_pluggable_device.so",
],
linkopts = select({ linkopts = select({
"//tensorflow:macos": ["-headerpad_max_install_names"], "//tensorflow:macos": ["-headerpad_max_install_names"],
"//conditions:default": [], "//conditions:default": [],
@ -704,6 +707,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core/platform:resource_loader",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
], ],
) )

View File

@ -37,7 +37,9 @@ limitations under the License.
#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/net.h" #include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/strcat.h"
@ -630,6 +632,9 @@ void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
namespace tensorflow { namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
// Helpers for loadding a TensorFlow PluggableDevice plugin (a .so file).
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
} // namespace tensorflow } // namespace tensorflow
void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes, void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
@ -743,3 +748,45 @@ void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
TF_ImportGraphDefOptions* opts, unsigned char enable) { TF_ImportGraphDefOptions* opts, unsigned char enable) {
opts->opts.validate_colocation_constraints = enable; opts->opts.validate_colocation_constraints = enable;
} }
// Load a Pluggable Device library.
// On success, returns a handle to the library and an OK status; otherwise
// returns nullptr and an error status.
//
// If `library_filename` has already been loaded, we return a cached handle.
// Device and Kernels/Ops are registered as globals when a library is loaded
// for the first time.
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
                                          TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "PluggableDevice plugin functionality is not supported on mobile");
  return nullptr;
#else
  TF_Library* lib_handle = new TF_Library;
  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
  static std::unordered_map<std::string, void*>* loaded_libs =
      new std::unordered_map<std::string, void*>();
  tensorflow::Env* env = tensorflow::Env::Default();
  {
    tensorflow::mutex_lock lock(mu);
    auto it = loaded_libs->find(library_filename);
    if (it != loaded_libs->end()) {
      lib_handle->lib_handle = it->second;
    } else {
      status->status =
          env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
      if (!status->status.ok()) {
        delete lib_handle;
        return nullptr;
      }
    }
    return lib_handle;
  }
#endif
}

void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
  delete lib_handle;
}
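
// Usage sketch (hypothetical caller and plugin path, illustrative only): load
// the plugin, check the status, and free only the bookkeeping handle.
void LoadMyDevicePluginSketch() {
  TF_Status* status = TF_NewStatus();
  TF_Library* lib = TF_LoadPluggableDeviceLibrary(
      "/opt/plugins/my_device_plugin.so", status);
  if (TF_GetCode(status) != TF_OK) {
    // Inspect TF_Message(status); lib is nullptr on failure.
  }
  // Deleting the handle does not unload the library; devices and kernels
  // registered at load time stay registered.
  if (lib != nullptr) TF_DeletePluggableDeviceLibraryHandle(lib);
  TF_DeleteStatus(status);
}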

View File

@ -304,6 +304,27 @@ TF_CAPI_EXPORT extern void
TF_ImportGraphDefOptionsSetValidateColocationConstraints(
    TF_ImportGraphDefOptions* opts, unsigned char enable);

// Load the library specified by library_filename and register the pluggable
// device and related kernels present in that library. This function is not
// supported on mobile and embedded platforms and will fail if called.
//
// Pass "library_filename" to a platform-specific mechanism for dynamically
// loading a library. The rules for determining the exact location of the
// library are platform-specific and are not documented here.
//
// On success, returns the newly created library handle and places OK in
// status. The caller owns the library handle.
//
// On failure, returns nullptr and places an error status in status.
TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary(
    const char* library_filename, TF_Status* status);

// Frees the memory associated with the library handle.
// Does NOT unload the library.
TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle(
    TF_Library* lib_handle);

#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
@ -234,5 +235,22 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
TF_DeleteTensor(tensor_1X6); TF_DeleteTensor(tensor_1X6);
} }
TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) {
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
  // Load the library.
  TF_Status* status = TF_NewStatus();
  string lib_path =
      tensorflow::GetDataDependencyFilepath(tensorflow::io::JoinPath(
          "tensorflow", "c", "experimental", "stream_executor", "test",
          "test_pluggable_device.so"));
  TF_Library* lib = TF_LoadPluggableDeviceLibrary(lib_path.c_str(), status);
  TF_Code code = TF_GetCode(status);
  string status_msg(TF_Message(status));
  TF_DeleteStatus(status);
  ASSERT_EQ(TF_OK, code) << status_msg;
  TF_DeletePluggableDeviceLibraryHandle(lib);
#endif  // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
}

}  // namespace
}  // namespace tensorflow

View File

@ -213,7 +213,11 @@ TF_Function* TF_GraphToFunctionWithControlOutputs(
    TF_DeleteFunction(tf_function);
    return nullptr;
  }

  tf_function->graph_with_debug_info = &fn_body->graph;
  for (const Node* n : fn_body->graph.nodes()) {
    tf_function->stack_traces[n->name()] = n->GetStackTrace();
  }

  return tf_function;
}

View File

@ -157,9 +157,7 @@ struct TF_DeviceList {
struct TF_Function {
  tensorflow::FunctionDef fdef;
  // Graph with nodes with debug stack traces.
  const tensorflow::Graph* graph_with_debug_info = nullptr;
  tensorflow::StackTracesMap stack_traces;
};

struct TF_ApiDefMap {

View File

@ -749,8 +749,8 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
                            TF_Status* status) {
  status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithDebugInfo(
      function->fdef, function->graph_with_debug_info);
  status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithStackTraces(
      function->fdef, function->stack_traces);
}

void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,

View File

@ -111,11 +111,11 @@ class ImmediateExecutionContext : public AbstractContext {
  // already exists.
  virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;

  // Same as `AddFunctionDef`, and additionally saves a pointer to the Graph
  // which has nodes containing stack traces for the nodes in `fdef`. Assumes
  // `graph` is alive while the function is alive.
  // Same as `AddFunctionDef`, but additionally saves the `stack_traces` under
  // the key of the function definition name (to be retrieved during function
  // instantiation).
  virtual Status AddFunctionDefWithDebugInfo(const FunctionDef& fdef,
                                             const Graph* graph) = 0;
  virtual Status AddFunctionDefWithStackTraces(
      const FunctionDef& fdef, const StackTracesMap& stack_traces) = 0;

  // Find and return an added function by its name.
  virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;
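
// Sketch of the storage contract described above (not the actual EagerContext
// code; the types below are simplified stand-ins): the traces are stored under
// the function definition's name so they can be retrieved at instantiation
// time.
#include <map>
#include <string>

struct FunctionDefSketch { std::string name; };
using StackTracesMapSketch = std::map<std::string, std::string>;  // node -> trace

class SimpleFunctionLibrarySketch {
 public:
  void AddFunctionDefWithStackTraces(const FunctionDefSketch& fdef,
                                     const StackTracesMapSketch& traces) {
    fdefs_[fdef.name] = fdef;
    stack_traces_[fdef.name] = traces;  // keyed by the function's name
  }
  const StackTracesMapSketch* FindStackTraces(const std::string& name) const {
    auto it = stack_traces_.find(name);
    return it == stack_traces_.end() ? nullptr : &it->second;
  }

 private:
  std::map<std::string, FunctionDefSketch> fdefs_;
  std::map<std::string, StackTracesMapSketch> stack_traces_;
};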

View File

@ -0,0 +1,17 @@
# Description:
#   test for stream_executor

load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_shared_object",
)

package(
    licenses = ["notice"],  # Apache 2.0
)

tf_cc_shared_object(
    name = "test_pluggable_device.so",
    srcs = ["test_pluggable_device.cc"],
    visibility = ["//tensorflow/c:__subpackages__"],
    deps = ["//tensorflow/c/experimental/stream_executor:stream_executor_hdrs"],
)

View File

@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,16 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
xla::StatusOr<pybind11::object> Bfloat16Dtype();
}  // namespace xla
#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_

#include "tensorflow/c/experimental/stream_executor/stream_executor.h"

void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
                   TF_Status* const status) {
  params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
  params->platform->name = "GPU";
  params->platform->type = "XGPU";
}

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/stream_executor/stream.h" #include "tensorflow/stream_executor/stream.h"
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
using tensorflow::errors::InvalidArgument;
// This file forms the basis of a stable ABI for third-party kernel // This file forms the basis of a stable ABI for third-party kernel
// implementations. It is crucial that changes to this file are made cautiously // implementations. It is crucial that changes to this file are made cautiously
// and with a focus on maintaining both source and binary compatibility. // and with a focus on maintaining both source and binary compatibility.
@ -87,9 +88,25 @@ void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name,
TF_SetStatus(status, TF_OK, ""); TF_SetStatus(status, TF_OK, "");
} }
#undef CASE #undef CASE
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow
namespace {
const tensorflow::AttrValue* GetAttrValue(TF_OpKernelConstruction* ctx,
                                          const char* attr_name,
                                          TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
  const tensorflow::AttrValue* attr =
      ::tensorflow::AttrSlice(cc_ctx->def()).Find(attr_name);
  if (attr == nullptr) {
    status->status = InvalidArgument("Operation '", cc_ctx->def().name(),
                                     "' has no attr named '", attr_name, "'.");
  }
  return attr;
}
}  // namespace

void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder,
                                     const char* attr_name,
                                     const TF_DataType type,
@ -257,7 +274,81 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
  cc_ctx->CtxFailure(s);
}
void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx,
const char* attr_name,
int32_t* list_size,
int32_t* total_size,
TF_Status* status) {
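// A list_size of -1 below means the attribute is not a list; a total_size of
// -1 means size is not meaningful for that attribute kind.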
const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);
if (!status->status.ok()) {
*list_size = -1;
*total_size = -1;
return;
}
switch (attr->value_case()) {
#define SINGLE_CASE(kK, attr_type, size_expr) \
case tensorflow::AttrValue::kK: \
*list_size = -1; \
*total_size = size_expr; \
break;
SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
SINGLE_CASE(kI, TF_ATTR_INT, -1);
SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
SINGLE_CASE(kShape, TF_ATTR_SHAPE,
attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
#undef SINGLE_CASE
case tensorflow::AttrValue::kList:
*list_size = 0;
*total_size = -1;
#define LIST_CASE(field, attr_type, ...) \
if (attr->list().field##_size() > 0) { \
*list_size = attr->list().field##_size(); \
__VA_ARGS__; \
break; \
}
LIST_CASE(
s, TF_ATTR_STRING, *total_size = 0;
for (int i = 0; i < attr->list().s_size();
++i) { *total_size += attr->list().s(i).size(); });
LIST_CASE(i, TF_ATTR_INT);
LIST_CASE(f, TF_ATTR_FLOAT);
LIST_CASE(b, TF_ATTR_BOOL);
LIST_CASE(type, TF_ATTR_TYPE);
LIST_CASE(
shape, TF_ATTR_SHAPE, *total_size = 0;
for (int i = 0; i < attr->list().shape_size(); ++i) {
const auto& s = attr->list().shape(i);
*total_size += s.unknown_rank() ? 0 : s.dim_size();
});
LIST_CASE(tensor, TF_ATTR_TENSOR);
LIST_CASE(tensor, TF_ATTR_FUNC);
#undef LIST_CASE
break;
case tensorflow::AttrValue::kPlaceholder:
*list_size = -1;
*total_size = -1;
break;
case tensorflow::AttrValue::kFunc:
*list_size = -1;
*total_size = -1;
break;
case tensorflow::AttrValue::VALUE_NOT_SET:
status->status =
InvalidArgument("Attribute '", attr_name, "' has no value set");
break;
}
}
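
// Sketch of the intended calling pattern for the list getters (hypothetical
// kernel-construction code, not part of this change; "Attr" is a placeholder
// attribute name): query the size first, then fetch into a buffer of that
// length. Assumes <cstdint> and <vector>.
static void ReadInt64ListAttrSketch(TF_OpKernelConstruction* ctx,
                                    TF_Status* status) {
  int32_t list_size = 0;
  int32_t total_size = 0;
  TF_OpKernelConstruction_GetAttrSize(ctx, "Attr", &list_size, &total_size,
                                      status);
  if (TF_GetCode(status) != TF_OK || list_size < 0) return;
  std::vector<int64_t> values(list_size);
  TF_OpKernelConstruction_GetAttrInt64List(ctx, "Attr", values.data(),
                                           list_size, status);
  // On success, `values` holds the elements of the "Attr" list attribute.
}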
#define DEFINE_TF_GETATTR(func, c_type, cc_type) \
#define DEFINE_TF_GETATTR(func, c_type, cc_type, attr_type, list_field) \
  void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx, \
                                             const char* attr_name, \
                                             c_type* val, TF_Status* status) { \
@ -269,10 +360,84 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
    if (s.ok()) { \
      *val = static_cast<c_type>(v); \
    } \
} \
void TF_OpKernelConstruction_GetAttr##func##List( \
TF_OpKernelConstruction* ctx, const char* attr_name, c_type* vals, \
int max_vals, TF_Status* status) { \
TF_SetStatus(status, TF_OK, ""); \
const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status); \
if (!status->status.ok()) return; \
if (attr->value_case() != tensorflow::AttrValue::kList) { \
status->status = \
InvalidArgument("Value for '", attr_name, "' is not a list."); \
return; \
} \
status->status = \
tensorflow::AttrValueHasType(*attr, "list(" attr_type ")"); \
if (!status->status.ok()) return; \
const auto len = std::min(max_vals, attr->list().list_field##_size()); \
for (int i = 0; i < len; ++i) { \
vals[i] = static_cast<c_type>(attr->list().list_field(i)); \
} \
} }
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
DEFINE_TF_GETATTR(Int32, tensorflow::int32, int32_t)
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType, "type", type)
DEFINE_TF_GETATTR(Int32, int32_t, tensorflow::int32, "int", i)
DEFINE_TF_GETATTR(Int64, int64_t, tensorflow::int64, "int", i)
DEFINE_TF_GETATTR(Float, float, float, "float", f)
DEFINE_TF_GETATTR(Bool, TF_Bool, bool, "bool", b)
void TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction* ctx,
const char* attr_name, char* value,
size_t max_length,
TF_Status* status) {
std::string v;
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
::tensorflow::Set_TF_Status_from_Status(status, s);
if (!status->status.ok()) return;
if (max_length <= 0) {
return;
}
std::memcpy(value, v.data(), std::min<size_t>(v.length(), max_length));
}
void TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction* ctx,
const char* attr_name,
char** values, size_t* lengths,
int max_values, void* storage,
size_t storage_size,
TF_Status* status) {
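// The strings are packed back to back into `storage`; on return, values[i]
// points into `storage` and lengths[i] holds the byte length of element i.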
std::vector<std::string> v;
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
::tensorflow::Set_TF_Status_from_Status(status, s);
if (!status->status.ok()) return;
const auto len = std::min(max_values, static_cast<int>(v.size()));
char* p = static_cast<char*>(storage);
for (int i = 0; i < len; ++i) {
const std::string& s = v[i];
values[i] = p;
lengths[i] = s.size();
if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
status->status = InvalidArgument(
"Not enough storage to hold the requested list of strings");
return;
}
memcpy(values[i], s.data(), s.size());
p += s.size();
}
}
bool TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction* ctx,
const char* attr_name, TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
return cc_ctx->HasAttr(attr_name);
}
TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) {
  auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelConstruction*>(ctx);

View File

@ -184,6 +184,24 @@ TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType(
// Returns the step ID of the given context.
TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx);
// Get the list_size and total_size of the attribute `attr_name` of `ctx`.
// list_size - the length of the list.
// total_size - total size of the list.
//   (1) If attr_type == TF_ATTR_STRING
//       then total_size is the cumulative byte size
//       of all the strings in the list.
//   (2) If attr_type == TF_ATTR_SHAPE and the attribute is a single shape
//       then total_size is the number of dimensions
//       of the shape valued attribute, or -1
//       if its rank is unknown.
//   (3) If attr_type == TF_ATTR_SHAPE and the attribute is a list of shapes
//       then total_size is the cumulative number
//       of dimensions of all shapes in the list.
//   (4) Otherwise, total_size is undefined.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrSize(
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* list_size,
int32_t* total_size, TF_Status* status);
// Interprets the named kernel construction attribute as a TF_DataType and
// places it into *val. *status is set to TF_OK.
//
@ -202,6 +220,112 @@ TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32(
    TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val,
    TF_Status* status);
// Interprets the named kernel construction attribute as int64_t and
// places it into *val. *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// int64, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64(
TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* val,
TF_Status* status);
// Interprets the named kernel construction attribute as float and
// places it into *val. *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// float, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloat(
TF_OpKernelConstruction* ctx, const char* attr_name, float* val,
TF_Status* status);
// Interprets the named kernel construction attribute as bool and
// places it into *val. *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// bool, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBool(
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* val,
TF_Status* status);
// Interprets the named kernel construction attribute as string and
// places it into *val. `val` must
// point to an array of length at least `max_length` (ideally set to
// total_size from TF_OpKernelConstruction_GetAttrSize(ctx,
// attr_name, list_size, total_size)). *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// string, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrString(
TF_OpKernelConstruction* ctx, const char* attr_name, char* val,
size_t max_length, TF_Status* status);
// Interprets the named kernel construction attribute as a TF_DataType array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTypeList(
TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* vals,
int max_vals, TF_Status* status);
// Interprets the named kernel construction attribute as int32_t array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32List(
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* vals,
int max_vals, TF_Status* status);
// Interprets the named kernel construction attribute as int64_t array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64List(
TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* vals,
int max_vals, TF_Status* status);
// Interprets the named kernel construction attribute as float array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloatList(
TF_OpKernelConstruction* ctx, const char* attr_name, float* vals,
int max_vals, TF_Status* status);
// Interprets the named kernel construction attribute as bool array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBoolList(
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* vals,
int max_vals, TF_Status* status);
// Interprets the named kernel construction attribute as string array and fills
// in `vals` and `lengths`, each of which must point to an array of length at
// least `max_values`. *status is set to TF_OK. The elements of values will
// point to addresses in `storage` which must be at least `storage_size` bytes
// in length. Ideally, max_values would be set to list_size and `storage` would
// be at least total_size, obtained from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrStringList(
TF_OpKernelConstruction* ctx, const char* attr_name, char** vals,
size_t* lengths, int max_values, void* storage, size_t storage_size,
TF_Status* status);
// Return true if the kernel construction has the attr_name
TF_CAPI_EXPORT extern bool TF_OpKernelConstruction_HasAttr(
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Status* status);
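
// Sketch of the storage contract above (hypothetical caller, not part of this
// header; "Attr" is a placeholder attribute name): size the buffers with
// TF_OpKernelConstruction_GetAttrSize, then let the call pack the string bytes
// into `storage`. Assumes <cstdint>, <string>, and <vector>.
static std::vector<std::string> ReadStringListAttrSketch(
    TF_OpKernelConstruction* ctx, TF_Status* status) {
  int32_t list_size = 0;
  int32_t total_size = 0;
  TF_OpKernelConstruction_GetAttrSize(ctx, "Attr", &list_size, &total_size,
                                      status);
  std::vector<std::string> result;
  if (TF_GetCode(status) != TF_OK || list_size <= 0) return result;
  std::vector<char*> values(list_size);
  std::vector<size_t> lengths(list_size);
  std::vector<char> storage(total_size);
  TF_OpKernelConstruction_GetAttrStringList(ctx, "Attr", values.data(),
                                            lengths.data(), list_size,
                                            storage.data(), storage.size(),
                                            status);
  if (TF_GetCode(status) != TF_OK) return result;
  for (int32_t i = 0; i < list_size; ++i) {
    result.emplace_back(values[i], lengths[i]);
  }
  return result;
}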
// Returns the unique operation name for this OpKernel.
TF_CAPI_EXPORT extern TF_StringView TF_OpKernelConstruction_GetName(
    TF_OpKernelConstruction* ctx);

View File

@ -161,6 +161,337 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
  ASSERT_TRUE(delete_called);
}
// REGISTER_OP for TF_OpKernelConstruction_GetAttr* test cases.
// Registers two ops, each with a single attribute called 'Attr'.
// The attribute in one op will have a type 'type', the other
// will have list(type).
#define ATTR_TEST_REGISTER_OP(name, type) \
REGISTER_OP("TestKernelAttr" #name) \
.Attr("Attr: " #type) \
.SetShapeFn(tensorflow::shape_inference::UnknownShape); \
REGISTER_OP("TestKernelAttr" #name "List") \
.Attr("Attr: list(" #type ")") \
.SetShapeFn(tensorflow::shape_inference::UnknownShape)
ATTR_TEST_REGISTER_OP(String, string);
ATTR_TEST_REGISTER_OP(Int, int);
ATTR_TEST_REGISTER_OP(Float, float);
ATTR_TEST_REGISTER_OP(Bool, bool);
ATTR_TEST_REGISTER_OP(Type, type);
#undef ATTR_TEST_REGISTER_OP
// Helper macros for the TF_OpKernelConstruction_GetAttr* tests.
#define EXPECT_TF_SIZE(attr_name, expected_list_size, expected_total_size) \
do { \
int32_t list_size, total_size; \
TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, &list_size, \
&total_size, status); \
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); \
EXPECT_EQ(expected_list_size, list_size); \
EXPECT_EQ(expected_total_size, total_size); \
} while (0)
typedef void* (*MyCreateFuncWithAttr)(TF_OpKernelConstruction*);
class TestKernelAttr : public ::testing::Test {
public:
TestKernelAttr() {}
~TestKernelAttr() {}
std::unique_ptr<OpKernel> GetFakeKernelWithAttr(const char* op_name,
AttrValue v, Status* status) {
NodeDef def;
def.set_op(op_name);
def.set_name("FakeNode");
def.set_device("FakeDevice");
(*def.mutable_attr())["Attr"] = v;
return CreateOpKernel(DeviceType("FakeDevice"), nullptr, nullptr, def, 1,
status);
}
void SetAttr(MyCreateFuncWithAttr MyCreateFuncAttr, const char* op_name,
AttrValue& v) {
TF_KernelBuilder* builder = TF_NewKernelBuilder(
op_name, "FakeDevice", MyCreateFuncAttr, &MyComputeFunc, &MyDeleteFunc);
{
TF_Status* status = TF_NewStatus();
TF_RegisterKernelBuilder("FakeNode", builder, status);
EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_DeleteStatus(status);
}
Status status;
std::unique_ptr<OpKernel> kernel =
GetFakeKernelWithAttr(op_name, v, &status);
TF_EXPECT_OK(status);
ASSERT_NE(nullptr, kernel.get());
kernel->Compute(nullptr);
ASSERT_TRUE(delete_called);
}
};
TEST_F(TestKernelAttr, String) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
std::unique_ptr<char[]> val(new char[5]);
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
/*expected_total_size*/ 5);
TF_OpKernelConstruction_GetAttrString(ctx, "Attr", val.get(),
/*max_length*/ 5, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ("bunny", string(static_cast<const char*>(val.get()), 5));
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
v.set_s("bunny");
SetAttr(my_create_func, "TestKernelAttrString", v);
}
TEST_F(TestKernelAttr, StringList) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
std::vector<string> list = {"bugs", "bunny", "duck"};
int list_total_size = 0;
for (const auto& s : list) {
list_total_size += s.size();
}
TF_Status* status = TF_NewStatus();
std::unique_ptr<char*[]> values(new char*[list.size()]);
std::unique_ptr<size_t[]> lens(new size_t[list.size()]);
std::unique_ptr<char[]> storage(new char[list_total_size]);
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list.size(),
/*expected_total_size*/ list_total_size);
TF_OpKernelConstruction_GetAttrStringList(
ctx, "Attr", values.get(), lens.get(), list.size(), storage.get(),
list_total_size, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
for (size_t i = 0; i < list.size(); ++i) {
EXPECT_EQ(list[i].size(), lens[i]) << i;
EXPECT_EQ(list[i], string(static_cast<const char*>(values[i]), lens[i]))
<< i;
}
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
auto attr_in = gtl::ArraySlice<StringPiece>({"bugs", "bunny", "duck"});
SetAttrValue(attr_in, &v);
SetAttr(my_create_func, "TestKernelAttrStringList", v);
}
TEST_F(TestKernelAttr, Int) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
int64_t val;
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrInt64(ctx, "Attr", &val, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(1234, val);
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
v.set_i(1234);
SetAttr(my_create_func, "TestKernelAttrInt", v);
}
TEST_F(TestKernelAttr, IntList) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
const int64_t list[] = {1, 2, 3, 4};
const size_t list_size = TF_ARRAYSIZE(list);
int64_t values[list_size];
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrInt64List(ctx, "Attr", values, list_size,
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_TRUE(
std::equal(std::begin(list), std::end(list), std::begin(values)));
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
auto attr_in = gtl::ArraySlice<int64>({1, 2, 3, 4});
SetAttrValue(attr_in, &v);
SetAttr(my_create_func, "TestKernelAttrIntList", v);
}
TEST_F(TestKernelAttr, Float) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
float val;
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrFloat(ctx, "Attr", &val, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_FLOAT_EQ(2.718, val);
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
v.set_f(2.718);
SetAttr(my_create_func, "TestKernelAttrFloat", v);
}
TEST_F(TestKernelAttr, FloatList) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
const float list[] = {1.414, 2.718, 3.1415};
const size_t list_size = TF_ARRAYSIZE(list);
float values[list_size];
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrFloatList(ctx, "Attr", values, list_size,
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_TRUE(
std::equal(std::begin(list), std::end(list), std::begin(values)));
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
auto attr_in = gtl::ArraySlice<float>({1.414, 2.718, 3.1415});
SetAttrValue(attr_in, &v);
SetAttr(my_create_func, "TestKernelAttrFloatList", v);
}
TEST_F(TestKernelAttr, Bool) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
unsigned char val;
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrBool(ctx, "Attr", &val, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(1, val);
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
v.set_b(1);
SetAttr(my_create_func, "TestKernelAttrBool", v);
}
TEST_F(TestKernelAttr, BoolList) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
const unsigned char list[] = {1, 0, 1, 0};
const size_t list_size = TF_ARRAYSIZE(list);
unsigned char values[list_size];
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrBoolList(ctx, "Attr", values, list_size,
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_TRUE(
std::equal(std::begin(list), std::end(list), std::begin(values)));
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
auto attr_in = gtl::ArraySlice<bool>({1, 0, 1, 0});
SetAttrValue(attr_in, &v);
SetAttr(my_create_func, "TestKernelAttrBoolList", v);
}
TEST_F(TestKernelAttr, Type) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
TF_DataType val;
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrType(ctx, "Attr", &val, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(TF_FLOAT, val);
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
v.set_type(DT_FLOAT);
SetAttr(my_create_func, "TestKernelAttrType", v);
}
TEST_F(TestKernelAttr, TypeList) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
const TF_DataType list[] = {TF_FLOAT, TF_DOUBLE, TF_HALF, TF_COMPLEX128};
const size_t list_size = TF_ARRAYSIZE(list);
TF_DataType values[list_size];
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrTypeList(ctx, "Attr", values, list_size,
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_TRUE(
std::equal(std::begin(list), std::end(list), std::begin(values)));
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
auto attr_in =
gtl::ArraySlice<DataType>({DT_FLOAT, DT_DOUBLE, DT_HALF, DT_COMPLEX128});
SetAttrValue(attr_in, &v);
SetAttr(my_create_func, "TestKernelAttrTypeList", v);
}
#undef EXPECT_TF_SIZE
class DummyDevice : public DeviceBase {
 public:
  explicit DummyDevice(Env* env) : DeviceBase(env) {}

View File

@ -65,6 +65,7 @@ MAP_HLO_TO_LHLO(MinOp);
MAP_HLO_TO_LHLO(MulOp);
MAP_HLO_TO_LHLO(NegOp);
MAP_HLO_TO_LHLO(NotOp);
MAP_HLO_TO_LHLO(OrOp);
MAP_HLO_TO_LHLO(RealOp);
MAP_HLO_TO_LHLO(ReduceOp);
MAP_HLO_TO_LHLO(ReshapeOp);
@ -81,6 +82,7 @@ MAP_HLO_TO_LHLO(SqrtOp);
MAP_HLO_TO_LHLO(SubOp);
MAP_HLO_TO_LHLO(TanhOp);
MAP_HLO_TO_LHLO(TransposeOp);
MAP_HLO_TO_LHLO(XorOp);

#undef MAP_HLO_TO_LHLO

View File

@ -481,6 +481,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
  return nullptr;
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::OrOp>(Location loc,
                                                 ArrayRef<Type> result_types,
                                                 ArrayRef<Value> args,
                                                 OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::OrOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
                                                    ArrayRef<Type> result_types,
@ -580,6 +589,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::XorOp>(Location loc,
                                                  ArrayRef<Type> result_types,
                                                  ArrayRef<Value> args,
                                                  OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::XOrOp>{}(
      loc, result_types, args, b);
}

}  // namespace impl

struct HloOpToStdScalarOp {

View File

@ -629,6 +629,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
      HloToLhloOpConverter<mhlo::MulOp>,
      HloToLhloOpConverter<mhlo::NegOp>,
      HloToLhloOpConverter<mhlo::NotOp>,
      HloToLhloOpConverter<mhlo::OrOp>,
      HloToLhloOpConverter<mhlo::RealOp>,
      HloToLhloOpConverter<mhlo::RemOp>,
      HloToLhloOpConverter<mhlo::RsqrtOp>,
@ -644,6 +645,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
      HloToLhloOpConverter<mhlo::SubOp>,
      HloToLhloOpConverter<mhlo::TanhOp>,
      HloToLhloOpConverter<mhlo::TransposeOp>,
      HloToLhloOpConverter<mhlo::XorOp>,
      HloToLhloReduceOpConverter,
      HloToLhloReturnOpConverter,
      HloToLhloTensorLoadOpConverter,

View File

@ -927,12 +927,14 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
      PointwiseToLinalgConverter<lmhlo::ExpOp>,
      PointwiseToLinalgConverter<lmhlo::FloorOp>,
      PointwiseToLinalgConverter<lmhlo::ImagOp>,
      PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
      PointwiseToLinalgConverter<lmhlo::LogOp>,
      PointwiseToLinalgConverter<lmhlo::MaxOp>,
      PointwiseToLinalgConverter<lmhlo::MinOp>,
      PointwiseToLinalgConverter<lmhlo::MulOp>,
      PointwiseToLinalgConverter<lmhlo::NegOp>,
      PointwiseToLinalgConverter<lmhlo::NotOp>,
      PointwiseToLinalgConverter<lmhlo::OrOp>,
      PointwiseToLinalgConverter<lmhlo::RealOp>,
      PointwiseToLinalgConverter<lmhlo::RemOp>,
      PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
@ -945,7 +947,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
      PointwiseToLinalgConverter<lmhlo::SqrtOp>,
      PointwiseToLinalgConverter<lmhlo::SubOp>,
      PointwiseToLinalgConverter<lmhlo::TanhOp>,
      PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
      PointwiseToLinalgConverter<lmhlo::XorOp>,
      ReduceConverter,
      ReshapeOpConverter<lmhlo::ReshapeOp>,
      ReverseConverter<lmhlo::ReverseOp>,
@ -1042,12 +1044,14 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
      PointwiseToLinalgConverter<mhlo::ExpOp, false>,
      PointwiseToLinalgConverter<mhlo::FloorOp, false>,
      PointwiseToLinalgConverter<mhlo::ImagOp, false>,
      PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
      PointwiseToLinalgConverter<mhlo::LogOp, false>,
      PointwiseToLinalgConverter<mhlo::MaxOp, false>,
      PointwiseToLinalgConverter<mhlo::MinOp, false>,
      PointwiseToLinalgConverter<mhlo::MulOp, false>,
      PointwiseToLinalgConverter<mhlo::NegOp, false>,
      PointwiseToLinalgConverter<mhlo::NotOp, false>,
      PointwiseToLinalgConverter<mhlo::OrOp, false>,
      PointwiseToLinalgConverter<mhlo::RealOp, false>,
      PointwiseToLinalgConverter<mhlo::RemOp, false>,
      PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
@ -1055,11 +1059,12 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
      PointwiseToLinalgConverter<mhlo::ShiftLeftOp, false>,
      PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp, false>,
      PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp, false>,
      PointwiseToLinalgConverter<mhlo::SignOp, false>,
      PointwiseToLinalgConverter<mhlo::SinOp, false>,
      PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
      PointwiseToLinalgConverter<mhlo::SubOp, false>,
      PointwiseToLinalgConverter<mhlo::TanhOp, false>,
      PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
      PointwiseToLinalgConverter<mhlo::XorOp, false>,
      ReshapeOpConverter<mhlo::ReshapeOp, false>,
      ReverseConverter<mhlo::ReverseOp, false>,
      TransposeConverter<mhlo::TransposeOp, false>>(context);

View File

@ -42,11 +42,11 @@ namespace {
      sep fn(SqrtOp) sep fn(TanhOp)

// TODO(herhut): Generate these out of op definitions.
#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep) \
  fn(AddOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp) sep fn(MaxOp) \
      sep fn(MinOp) sep fn(MulOp) sep fn(PowOp) sep fn(RemOp) \
      sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
      sep fn(ShiftRightLogicalOp) sep fn(SubOp)
  fn(AddOp) sep fn(AndOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp) \
      sep fn(MaxOp) sep fn(MinOp) sep fn(MulOp) sep fn(OrOp) sep fn(PowOp) \
      sep fn(RemOp) sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
      sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp)

// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \

View File

@ -316,6 +316,20 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// -----
// CHECK-LABEL: func @and
func @and(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
%result: memref<2x2xi32>) {
%tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
%tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
%tensor_result = "mhlo.and"(%tensor_operand0, %tensor_operand1)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// CHECK: "lmhlo.and"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// CHECK-LABEL: func @ceil
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
@ -389,6 +403,20 @@ func @not(%operand: memref<2x2xi32>, %result: memref<2x2xi32>) {
// -----
// CHECK-LABEL: func @or
func @or(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
%result: memref<2x2xi32>) {
%tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
%tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
%tensor_result = "mhlo.or"(%tensor_operand0, %tensor_operand1)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// CHECK: "lmhlo.or"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// CHECK-LABEL: func @rsqrt
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
@ -480,7 +508,8 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// -----

// CHECK-LABEL: func @remainder
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
                %result: memref<2x2xf32>) {
  %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
  %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
  %tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
@ -492,6 +521,20 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
// -----
// CHECK-LABEL: func @xor
func @xor(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
%result: memref<2x2xi32>) {
%tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
%tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
%tensor_result = "mhlo.xor"(%tensor_operand0, %tensor_operand1)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// CHECK: "lmhlo.xor"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// Dynamic shape binary element-wise operation.
// CHECK-LABEL: func @add_dyn
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {

View File

@ -194,6 +194,30 @@ func @integer_and(%lhs: tensor<2x2xi32>,
// -----
// CHECK-LABEL: func @integer_or
func @integer_or(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: or
%0 = "mhlo.or"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @integer_xor
func @integer_xor(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: xor
%0 = "mhlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @float_cmp
func @float_cmp(%lhs: tensor<2x2xf32>,
                %rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {

View File

@ -509,7 +509,7 @@ Operation* BuildVariableOp(const tflite::TensorT& tensor,
    return op.getOperation();
  }
  auto op = builder.create<tfl::ConstOp>(loc, value);
  if (!tensor.quantization->min.empty()) {
  if (tensor.quantization && !tensor.quantization->min.empty()) {
    if (auto stats_op =
            ConvertMinMaxToStatsOp(tensor, builder, op.getResult())) {
      return stats_op;

View File

@ -1977,6 +1977,7 @@ cc_library(
hdrs = ["utils/bridge_logger.h"], hdrs = ["utils/bridge_logger.h"],
deps = [ deps = [
":dump_mlir_util", ":dump_mlir_util",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",

View File

@ -5028,6 +5028,26 @@ Table initializer that takes two tensors for keys and values respectively.
  TF_DerivedOperandTypeAttr Tkey = TF_DerivedOperandTypeAttr<1>;
}
def TF_InplaceAddOp : TF_Op<"InplaceAdd", [AllTypesMatch<["x", "y"]>, NoSideEffect]> {
let summary = "Adds v into specified rows of x.";
let description = [{
Computes y = x; y[i, :] += v; return y.
}];
let arguments = (ins
TF_Tensor:$x,
TF_Int32Tensor:$i,
TF_Tensor:$v
);
let results = (outs
TF_Tensor:$y
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
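
// For intuition, a standalone C++ sketch (illustration only, not part of this
// file) of the `y = x; y[i, :] += v` semantics on row-major buffers, where
// `i` holds the row indices updated by the matching rows of `v`:
//   std::vector<float> InplaceAddReference(const std::vector<float>& x,
//                                          const std::vector<int32_t>& i,
//                                          const std::vector<float>& v,
//                                          int64_t cols) {
//     std::vector<float> y = x;  // y = x
//     for (size_t k = 0; k < i.size(); ++k) {
//       for (int64_t c = 0; c < cols; ++c) {
//         y[i[k] * cols + c] += v[k * cols + c];  // y[i[k], :] += v[k, :]
//       }
//     }
//     return y;
//   }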
def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> {
  let summary = "Updates specified rows 'i' with values 'v'.";
@ -5374,6 +5394,37 @@ def TF_IteratorV2Op : TF_Op<"IteratorV2", []> {
  );
}
def TF_KthOrderStatisticOp : TF_Op<"KthOrderStatistic", [NoSideEffect]> {
let summary = "Computes the Kth order statistic of a data set. The current";
let description = [{
implementation uses a binary search requiring exactly 32 passes over
the input data. The running time is linear with respect to input
size. The median-of-medians algorithm is probably faster, but is
difficult to implement efficiently in XLA. The implementation imposes
a total ordering on floats. The ordering is consistent with the usual
partial order. Positive NaNs are greater than positive
infinity. Negative NaNs are less than negative infinity. NaNs with
distinct payloads are treated as distinct. Subnormal numbers are
preserved (not flushed to zero). Positive infinity is greater than all
numbers. Negative infinity is less than all numbers. Positive is
greater than negative zero. There are less than k values greater than
the kth order statistic. There are at least k values greater than or
equal to the Kth order statistic. The semantics are not the same as
top_k_unique.
}];
let arguments = (ins
TF_Float32Tensor:$input,
I64Attr:$k
);
let results = (outs
TF_Float32Tensor:$output
);
}
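
// One common way to impose the float total ordering described above (a sketch,
// not necessarily the XLA implementation): map each float to an unsigned key
// whose integer order matches the total order, so negative NaNs sort below
// -inf and positive NaNs above +inf:
//   uint32_t TotalOrderKey(float x) {
//     uint32_t bits;
//     std::memcpy(&bits, &x, sizeof(bits));
//     // Negative values: flip all bits; non-negative values: set the sign bit.
//     return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
//   }
//   // a precedes b in the total order iff TotalOrderKey(a) < TotalOrderKey(b).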
def TF_L2LossOp : TF_Op<"L2Loss", [NoSideEffect]> {
  let summary = "L2 Loss.";
@ -6505,6 +6556,27 @@ iterator in `iterator` to the first element of `dataset`.
  let results = (outs);
}
def TF_MakeUniqueOp : TF_Op<"MakeUnique", [NoSideEffect]> {
let summary = [{
Make all elements in the non-Batch dimension unique, but \"close\" to
}];
let description = [{
their initial value. Never returns a sub-normal number. Never returns
zero. The sign of each input element is always identical to the sign
of the corresponding output element. Behavior for infinite elements is
undefined. Behavior for subnormal elements is undefined.
}];
let arguments = (ins
TF_Float32Tensor:$input
);
let results = (outs
TF_Float32Tensor:$output
);
}
def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> {
  let summary = [{
Multiply the matrix "a" by the matrix "b".
@ -15234,6 +15306,36 @@ array([[1, 2, 3, 1, 2, 3],
  let hasFolder = 1;
}
def TF_TopKUniqueOp : TF_Op<"TopKUnique", [NoSideEffect]> {
let summary = "Returns the TopK unique values in the array in sorted order.";
let description = [{
The running time is proportional to the product of K and the input
size. Sorting the whole array is more efficient for sufficiently large
values of K. The median-of-medians algorithm is probably faster, but
difficult to implement efficiently in XLA. If there are fewer than K
unique numbers (not NANs), the results are padded with negative
infinity. NaNs are never returned. Subnormal numbers are flushed to
zero. If an element appears at multiple indices, the highest index is
returned. If a TopK element never appears in the input due to padding
values, the indices are padded with negative one. If a padding value
appears in the input and padding is needed, the highest index of the
padding value will be returned. The semantics are not the same as
kth_order_statistic.
}];
let arguments = (ins
TF_Float32Tensor:$input,
I64Attr:$k
);
let results = (outs
TF_Float32Tensor:$topk,
TF_Int32Tensor:$topk_indices
);
}
def TF_TopKV2Op : TF_Op<"TopKV2", [NoSideEffect]> {
  let summary = [{
Finds values and indices of the `k` largest elements for the last dimension.
@ -15269,6 +15371,29 @@ If two elements are equal, the lower-index element appears first.
  let verifier = [{ return Verify(*this); }];
}
def TF_TopKWithUniqueOp : TF_Op<"TopKWithUnique", [NoSideEffect]> {
let summary = "Returns the TopK values in the array in sorted order.";
let description = [{
This is a combination of MakeUnique and TopKUnique. The returned top-K will
have its lower bits replaced by iota, thus it will be close to the original
value but not exactly the same. The running time is proportional to the product
of K and the input size. NaNs are never returned. Subnormal numbers are flushed
to zero.
}];
let arguments = (ins
TF_Float32Tensor:$input,
I64Attr:$k
);
let results = (outs
TF_Float32Tensor:$topk,
TF_Int32Tensor:$topk_indices
);
}
def TF_TransposeOp : TF_Op<"Transpose", [NoSideEffect]> {
  let summary = "Shuffle dimensions of x according to a permutation.";
@@ -532,7 +532,11 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
auto ranked_type = type.dyn_cast<RankedTensorType>();
if (!ranked_type) return {};
// DenseIntElementsAttr::get requires the output type be ranked with static
// shape.
auto output_type = getType().dyn_cast<RankedTensorType>();
if (!output_type || !output_type.hasStaticShape()) return {};
int32_t rank = ranked_type.getRank();
return DenseIntElementsAttr::get(output_type, rank);
}
@@ -904,6 +904,20 @@ func @testRankOfRankedTensor(%arg0 : tensor<4x3x2xf32>) -> tensor<i32> {
return %0 : tensor<i32>
}
// CHECK-LABEL: testRankOfRankedTensorUnrankedOutput
func @testRankOfRankedTensorUnrankedOutput(%arg0 : tensor<4x3x2xf32>) -> tensor<*xi32> {
// Regression test to make sure we don't crash in this case.
%0 = "tf.Rank"(%arg0) : (tensor<4x3x2xf32>) -> tensor<*xi32>
return %0 : tensor<*xi32>
}
// CHECK-LABEL: testRankOfRankedTensorDynamicShapeOutput
func @testRankOfRankedTensorDynamicShapeOutput(%arg0 : tensor<4x3x2xf32>) -> tensor<?xi32> {
// Regression test to make sure we don't crash in this case.
%0 = "tf.Rank"(%arg0) : (tensor<4x3x2xf32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// CHECK-LABEL: @foldFill
func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>) {
%0 = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
@@ -1627,8 +1627,10 @@ LogicalResult InferModuleShape(ModuleOp module, int64_t max_iterations) {
return success();
}
int64_t producer = producer_or.ValueOrDie();
// TODO(jpienaar): Clean up propagate_caller_callee_constants if it is no
// longer needed.
ShapeInference context(producer, module.getContext(),
/*propagate_caller_callee_constants=*/false);
if (auto main = module.lookupSymbol<mlir::FuncOp>("main"))
context.enqueue(main);
for (auto func : module.getOps<FuncOp>()) context.enqueue(func);
@@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
#include <atomic>
#include "absl/strings/str_split.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Operation.h" // from @llvm-project
@@ -23,17 +26,30 @@ limitations under the License.
namespace tensorflow {
// Counter is used as a prefix for filenames.
static std::atomic<int> log_counter(0);
BridgeLoggerConfig::BridgeLoggerConfig(bool print_module_scope,
bool print_after_only_on_change)
: mlir::PassManager::IRPrinterConfig(print_module_scope,
print_after_only_on_change) {
const char* log_pass_patterns = getenv("MLIR_BRIDGE_LOG_PASS_PATTERNS");
if (log_pass_patterns) {
log_pass_patterns_ =
absl::StrSplit(log_pass_patterns, ',', absl::SkipWhitespace());
}
}
// Logs op to file with name of format
// `<log_counter>_mlir_bridge_<pass_name>_<file_suffix>.mlir`.
inline static void Log(BridgeLoggerConfig::PrintCallbackFn print_callback,
mlir::Pass* pass, mlir::Operation* op,
llvm::StringRef file_suffix) {
std::string pass_name = pass->getName().str();
// Add 4-digit counter as prefix so the order of the passes is obvious.
std::string name = llvm::formatv("{0,0+4}_mlir_bridge_{1}_{2}", log_counter++,
pass_name, file_suffix);
std::unique_ptr<llvm::raw_ostream> os;
std::string filepath;
@@ -44,13 +60,30 @@ inline static void Log(BridgeLoggerConfig::PrintCallbackFn print_callback,
void BridgeLoggerConfig::printBeforeIfEnabled(mlir::Pass* pass,
mlir::Operation* operation,
PrintCallbackFn print_callback) {
if (should_print(pass)) Log(print_callback, pass, operation, "before");
}
void BridgeLoggerConfig::printAfterIfEnabled(mlir::Pass* pass,
mlir::Operation* operation,
PrintCallbackFn print_callback) {
if (should_print(pass)) Log(print_callback, pass, operation, "after");
}
bool BridgeLoggerConfig::should_print(mlir::Pass* pass) {
if (log_pass_patterns_.empty()) return true;
std::string pass_name = pass->getName().str();
for (const auto& pattern : log_pass_patterns_) {
if (pass_name.find(pattern) != std::string::npos) {
// pattern matches pass
return true;
}
}
// no pattern matches pass
VLOG(2) << "Not logging pass " << pass_name
<< " because it does not match any pattern in "
"MLIR_BRIDGE_LOG_PASS_PATTERNS";
return false;
}
void BridgeTimingConfig::printTiming(PrintCallbackFn printCallback) {
@@ -23,7 +23,11 @@ limitations under the License.
namespace tensorflow {
// Logger for logging/dumping MLIR modules before and after passes in bridge
// targeting TPUs. The passes being logged can be restricted via environment
// variable `MLIR_BRIDGE_LOG_PASS_PATTERNS` which is interpreted as a comma-
// separated list of strings, and only passes whose name contains any of those
// strings as a substring are logged (no regex support). If
// `MLIR_BRIDGE_LOG_PASS_PATTERNS` is not defined, then all passes are logged.
class BridgeLoggerConfig : public mlir::PassManager::IRPrinterConfig {
public:
explicit BridgeLoggerConfig(bool print_module_scope = false,
@@ -42,6 +46,14 @@ class BridgeLoggerConfig : public mlir::PassManager::IRPrinterConfig {
// with the stream to dump into.
void printAfterIfEnabled(mlir::Pass *pass, mlir::Operation *operation,
PrintCallbackFn print_callback) override;
private:
bool should_print(mlir::Pass *pass);
// Only print passes that match any of these patterns. A pass matches a
// pattern if its name contains the pattern as a substring. If
// `log_pass_patterns_` is empty, print all passes.
std::vector<std::string> log_pass_patterns_;
};
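For orientation, a hypothetical usage sketch (not part of this change; the pass names below are examples only): set the environment variable and install the config on a pass manager so that only matching passes are dumped.

#include <cstdlib>
#include <memory>

#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"

// Dumps IR only for passes whose name contains "ShapeInference" or
// "LegalizeTF"; every other pass is skipped by should_print().
void EnableFilteredBridgeLogging(mlir::PassManager& pm) {
  setenv("MLIR_BRIDGE_LOG_PASS_PATTERNS", "ShapeInference,LegalizeTF",
         /*overwrite=*/1);
  pm.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>());
}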
// Logger for logging/dumping pass pipeline timings after completion.
@@ -516,6 +516,11 @@ class ConvertToHloModule {
//
// TODO(hinsu): Check for dynamic shapes and exit instead of crashing.
LogicalResult Run() {
auto main = module_.lookupSymbol<mlir::FuncOp>("main");
if (!main)
return module_.emitError(
"conversion requires module with `main` function");
for (auto func : module_.getOps<FuncOp>()) {
if (func.empty()) continue;
if (failed(RunOnFunction(func))) return failure();
@@ -539,8 +544,11 @@ class ConvertToHloModule {
xla::XlaComputation* result);
::xla::HloModuleProto ConsumeMainProto() {
auto main = module_.lookupSymbol<mlir::FuncOp>("main");
// This is an invariant check as Run returns failure if there is no main
// function and so the main proto shouldn't be consumed in that case.
CHECK(main) << "requires module to have main function"; // Crash Ok.
return lowered_computation_[main].proto();
}
// Lower function call to HLO call instruction
@@ -174,6 +174,13 @@ func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
return %0: tensor<4xi32>
}
// CHECK-LABEL: func @bitwise_xor
func @bitwise_xor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-NEXT: mhlo.xor
%0 = "tf.BitwiseXor"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0: tensor<4xi32>
}
// CHECK-LABEL: func @bitwise_and
func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-NEXT: mhlo.and
@@ -835,7 +835,14 @@ func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?
}
// CHECK-LABEL: func @floordiv_unranked
func @floordiv_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NOT: tf.FloorDiv
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0: tensor<*xf32>
}
// CHECK-LABEL: func @floordiv_int
func @floordiv_int(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
// CHECK: tf.FloorDiv
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
return %0: tensor<*xi32>
@@ -894,7 +901,7 @@ func @floormod_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?
// CHECK-LABEL: func @floormod_unranked
func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
// CHECK-NOT: tf.FloorMod
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
return %0: tensor<*xi32>
}
@ -0,0 +1,7 @@
// RUN: not tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s 2>&1 | FileCheck %s
// CHECK: conversion requires module with `main`
func @non_main() {
%0 = "mhlo.constant"() {value = opaque<"mhlo", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32>
return
}
@@ -114,7 +114,7 @@ def : Pat<(TF_ComplexOp $r, $i), (HLO_ComplexOp $r, $i)>;
// Performs a substitution of FloorDiv, pseudo code below:
//
// return floor(div(x, y))
def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r),
(HLO_FloorOp
(HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))),
[(IEEEFloatTensor $l)]>;
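A tiny numeric aside on why the floor matters here (illustrative example, not part of the pattern file): truncating division and floor division disagree exactly when the operand signs differ.

#include <cmath>
#include <cstdio>

int main() {
  double x = -7.0, y = 2.0;
  std::printf("trunc(x/y) = %d, floor(x/y) = %d\n",
              static_cast<int>(x / y),               // -3 (truncates toward 0)
              static_cast<int>(std::floor(x / y)));  // -4 (FloorDiv result)
  return 0;
}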
@@ -166,7 +166,7 @@ def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
// return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y
// Requires static shaped inputs to create constant splats and computation of
// broadcast attributes.
def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r),
(HLO_SelectOp
(HLOClient_BroadcastAndOp
(HLOClient_BroadcastCompareOp
@@ -193,14 +193,15 @@ def : Pat<(TF_FloorModOp AnyRankedTensor:$l, AnyRankedTensor:$r),
//===----------------------------------------------------------------------===//
class DirectLogicalBinaryPat<Op FromOp, Op ToOp>
: Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
(ToOp $l, $r, (BinBroadcastDimensions $l, $r)),
[(SignedIntTensor $l)]>;
foreach fromToBinPair = [[TF_LogicalAndOp, HLOClient_BroadcastAndOp],
[TF_LogicalOrOp, HLOClient_BroadcastOrOp],
[TF_BitwiseAndOp, HLOClient_BroadcastAndOp],
[TF_BitwiseOrOp, HLOClient_BroadcastOrOp],
[TF_BitwiseXorOp, HLOClient_BroadcastXorOp]] in
def : DirectLogicalBinaryPat<fromToBinPair[0], fromToBinPair[1]>;
//===----------------------------------------------------------------------===//
@@ -154,9 +154,11 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::IgammaOp>(),
TypeID::get<TF::IgammacOp>(),
TypeID::get<TF::IgammaGradAOp>(),
TypeID::get<TF::InplaceAddOp>(),
TypeID::get<TF::InTopKV2Op>(),
TypeID::get<TF::InvertOp>(),
TypeID::get<TF::InvOp>(),
TypeID::get<TF::KthOrderStatisticOp>(),
TypeID::get<TF::LRNOp>(),
TypeID::get<TF::LRNGradOp>(),
TypeID::get<TF::LeakyReluGradOp>(),
@@ -170,6 +172,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::LogicalOrOp>(),
TypeID::get<TF::LogOp>(),
TypeID::get<TF::LowerBoundOp>(),
TypeID::get<TF::MakeUniqueOp>(),
TypeID::get<TF::MatMulOp>(),
TypeID::get<TF::MatrixDiagV3Op>(),
TypeID::get<TF::MatrixInverseOp>(),
@@ -248,6 +251,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::TensorScatterAddOp>(),
TypeID::get<TF::TensorScatterSubOp>(),
TypeID::get<TF::TPUEmbeddingActivationsOp>(),
TypeID::get<TF::TopKUniqueOp>(),
TypeID::get<TF::TopKWithUniqueOp>(),
TypeID::get<TF::TransposeOp>(),
TypeID::get<TF::TridiagonalSolveOp>(),
TypeID::get<TF::TruncateDivOp>(),
@@ -121,7 +121,7 @@ Status ConvertModule(std::unique_ptr<HloModule> hlo_module, ModuleOp module,
// Run all HLO passes to produce an optimized module.
auto result_or = backend->compiler()->RunHloPassesAndBufferAssignement(
std::move(hlo_module), backend->default_stream_executor(),
optimize_xla_hlo, {backend->memory_allocator()});
TF_RETURN_WITH_CONTEXT_IF_ERROR(result_or.status(),
"running XLA pass pipeline");
std::unique_ptr<HloModule> optimized_hlo_module =
@@ -115,6 +115,16 @@ class ExecutableBuildOptions {
return *this;
}
// Thread pool for parallel compilation.
tensorflow::thread::ThreadPool* compile_thread_pool() const {
return compile_thread_pool_;
}
ExecutableBuildOptions& set_compile_thread_pool(
tensorflow::thread::ThreadPool* compile_thread_pool) {
compile_thread_pool_ = compile_thread_pool;
return *this;
}
private:
int device_ordinal_ = -1;
Shape result_layout_;
@@ -128,6 +138,7 @@ class ExecutableBuildOptions {
absl::optional<DeviceAssignment> device_assignment_;
bool alias_passthrough_params_ = false;
bool run_backend_only_ = false;
tensorflow::thread::ThreadPool* compile_thread_pool_ = nullptr;
};
} // namespace xla
@@ -3347,6 +3347,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
// constant False if dimension is static.
// - Reduce: Convert to reduce or.
// - Constant: Convert to constant False.
// - Reshape, slice, transpose, pad:
// Convert into predicate type with same opcode.
// - Other ops: Not supported.
// Create the instruction for the new handle.
TF_ASSIGN_OR_RETURN(HloOpcode opcode,
@@ -3449,6 +3451,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
case HloOpcode::kBroadcast:
case HloOpcode::kConcatenate:
case HloOpcode::kReshape:
case HloOpcode::kPad:
break;
case HloOpcode::kGetDimensionSize: {
int64 dimension = instr_proto->dimensions(0);
@@ -26,8 +26,8 @@ static const char kCpuPlatformName[] = "cpu";
CpuDevice::CpuDevice(int id,
std::unique_ptr<LocalDeviceState> local_device_state)
: PjRtStreamExecutorDevice(id, std::move(local_device_state),
/*device_kind=*/kCpuPlatformName) {}
StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
TF_ASSIGN_OR_RETURN(se::Platform * platform,
@@ -40,7 +40,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
TF_ASSIGN_OR_RETURN(LocalClient * client,
ClientLibrary::GetOrCreateLocalClient(options));
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
for (int i = 0; i < client->device_count(); ++i) {
se::StreamExecutorConfig config;
config.ordinal = i;
@@ -57,11 +57,11 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
devices.push_back(std::move(device));
}
return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
kCpuName, client, std::move(devices), /*host_id=*/0,
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr));
}
} // namespace xla
@@ -23,7 +23,7 @@ limitations under the License.
namespace xla {
class CpuDevice : public PjRtStreamExecutorDevice {
public:
CpuDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state);
};
@@ -35,9 +35,9 @@ namespace xla {
namespace {
// A custom PjRtClient that overrides the device assignment method.
class GpuClient : public xla::PjRtStreamExecutorClient {
public:
using xla::PjRtStreamExecutorClient::PjRtStreamExecutorClient;
xla::StatusOr<xla::DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override;
@@ -55,7 +55,8 @@ xla::StatusOr<xla::DeviceAssignment> GpuClient::GetDefaultDeviceAssignment(
return assignment;
}
// Fallback to default global device assignment if we can't run locally.
return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
num_partitions);
}
// Builds an xla::LocalClient for the GPU platform.
@@ -225,9 +226,9 @@ StatusOr<std::string> NcclIdStore::GetNcclUniqueId(
return result.first->second;
}
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> BuildLocalDevices(
std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
for (auto& local_device : local_device_states) {
int device_ordinal = local_device->device_ordinal();
const se::DeviceDescription& description =
@@ -243,7 +244,7 @@ std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
Status BuildDistributedDevices(
std::vector<std::unique_ptr<LocalDeviceState>> local_device_states,
std::shared_ptr<DistributedRuntimeClient> distributed_client, int node_id,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>* devices,
gpu::GpuExecutableRunOptions* gpu_executable_run_options) {
LocalTopologyProto local_topology;
local_topology.set_node_id(node_id);
@@ -306,8 +307,8 @@ Status BuildDistributedDevices(
GpuDevice::GpuDevice(int id,
std::unique_ptr<LocalDeviceState> local_device_state,
std::string device_kind, int node_id)
: PjRtStreamExecutorDevice(id, std::move(local_device_state),
std::move(device_kind), node_id) {}
StatusOr<std::unique_ptr<PjRtClient>> GetGpuClient(
bool asynchronous, const GpuAllocatorConfig& allocator_config,
@@ -322,7 +323,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetGpuClient(
auto host_memory_allocator =
GetGpuHostAllocator(local_device_states.front()->executor());
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
auto gpu_run_options = absl::make_unique<gpu::GpuExecutableRunOptions>();
if (distributed_client) {
TF_RETURN_IF_ERROR(BuildDistributedDevices(
@@ -25,7 +25,7 @@ limitations under the License.
namespace xla {
class GpuDevice : public PjRtStreamExecutorDevice {
public:
GpuDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state,
std::string device_kind, int node_id);
@@ -26,8 +26,8 @@ static const char kInterpreterPlatformName[] = "interpreter";
InterpreterDevice::InterpreterDevice(
int id, std::unique_ptr<LocalDeviceState> local_device_state)
: PjRtStreamExecutorDevice(id, std::move(local_device_state),
/*device_kind=*/kInterpreterPlatformName) {}
StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
TF_ASSIGN_OR_RETURN(se::Platform * platform,
@@ -41,7 +41,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
TF_ASSIGN_OR_RETURN(LocalClient * client,
ClientLibrary::GetOrCreateLocalClient(options));
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
se::StreamExecutor* executor =
client->backend().stream_executor(0).ValueOrDie();
auto device_state = absl::make_unique<LocalDeviceState>(
@@ -51,11 +51,11 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
absl::make_unique<InterpreterDevice>(0, std::move(device_state));
devices.push_back(std::move(device));
return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
"interpreter", client, std::move(devices), /*host_id=*/0,
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr));
}
} // namespace xla
@@ -23,7 +23,7 @@ limitations under the License.
namespace xla {
class InterpreterDevice : public PjRtStreamExecutorDevice {
public:
InterpreterDevice(int id,
std::unique_ptr<LocalDeviceState> local_device_state);
@@ -114,21 +114,22 @@ limitations under the License.
namespace xla {
PjRtPlatformId PjRtStreamExecutorDevice::platform_id() const {
return client_->platform_id();
}
const std::string& PjRtStreamExecutorDevice::platform_name() const {
return client_->platform_name();
}
StatusOr<LocalDeviceState*> PjRtStreamExecutorDevice::GetLocalDeviceState()
const {
if (local_device_state_) {
return local_device_state_.get();
}
return InvalidArgument("Device %s is not a local device.", DebugString());
}
std::string PjRtStreamExecutorDevice::DebugString() const {
return absl::StrCat(platform_name(), ":", id());
}
@@ -153,14 +154,15 @@ StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
devices[replica].size(), replica, devices[0].size());
}
for (int partition = 0; partition < devices[replica].size(); ++partition) {
if (devices[0][0]->client()->platform_id() !=
devices[replica][partition]->client()->platform_id()) {
return InvalidArgument(
"Device assignment passed to Compile() must have devices of a "
"single kind, got %s for replica 0 partition 0 and %s for replica "
"%d partition %d.",
devices[0][0]->client()->platform_name(),
devices[replica][partition]->client()->platform_name(), replica,
partition);
}
xla_assignment(replica, partition) = devices[replica][partition]->id();
}
@@ -182,9 +184,9 @@ class CpuAllocator : public tensorflow::Allocator {
}
};
PjRtStreamExecutorClient::PjRtStreamExecutorClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id,
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
@@ -193,7 +195,7 @@ PjRtClient::PjRtClient(
platform_name_(std::move(platform_name)),
client_(client),
host_memory_allocator_(std::move(host_memory_allocator)),
owned_devices_(std::move(devices)),
host_id_(host_id),
owned_allocator_(std::move(allocator)),
should_stage_host_to_device_transfers_(
@@ -211,12 +213,14 @@ PjRtClient::PjRtClient(
host_memory_allocator_ = std::make_unique<CpuAllocator>();
}
for (const std::unique_ptr<PjRtStreamExecutorDevice>& device :
owned_devices_) {
devices_.push_back(device.get());
CHECK(id_to_device_.insert({device->id(), device.get()}).second)
<< "Duplicate device id: " << device->id();
if (device->IsAddressable()) {
int idx = device->local_hardware_id();
if (idx >= local_devices_.size()) {
local_devices_.resize(idx + 1);
}
@@ -230,13 +234,14 @@ PjRtClient::PjRtClient(
}
}
StatusOr<DeviceAssignment> PjRtStreamExecutorClient::GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const {
return client_->backend().computation_placer()->AssignDevices(num_replicas,
num_partitions);
}
std::unique_ptr<HloCostAnalysis>
PjRtStreamExecutorClient::GetHloCostAnalysis() {
return absl::make_unique<HloCostAnalysis>(
client_->backend().compiler()->ShapeSizeBytesFunction());
}
@@ -346,12 +351,13 @@ StatusOr<std::unique_ptr<PjRtBuffer>> AllocateDestinationBuffer(
return InvalidArgument("Can't make a buffer from an empty tuple");
}
auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client);
TransferManager* transfer_manager =
se_client->client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer,
transfer_manager->AllocateScopedShapedBuffer(
on_host_shape, se_client->allocator(),
local_device->device_ordinal()));
if (local_device->allocation_model() ==
LocalDeviceState::kComputeSynchronized) {
if (copy_stream == nullptr) {
@@ -543,18 +549,21 @@ void PjRtBuffer::ScopedHold::AddToInput(
bool PjRtBuffer::IsOnCpu() const { return client()->platform_id() == kCpuId; }
StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::BufferFromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme(
"PjRtStreamExecutorClient::BufferFromHostBuffer");
VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: "
<< shape.ToString() << " device: " << device->DebugString();
if (shape.IsTuple()) {
return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple");
}
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->GetLocalDeviceState());
int64 size = ShapeUtil::ByteSizeOf(shape);
TransferManager* transfer_manager = client()->backend().transfer_manager();
@@ -708,20 +717,23 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
return py_buffer;
}
StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::CreateUninitializedBuffer(const Shape& shape,
PjRtDevice* device) {
return CreateUninitializedBuffer(shape, device, nullptr);
}
StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device,
std::shared_ptr<BufferSequencingEvent> definition_event) {
tensorflow::profiler::TraceMe traceme(
"PjRtStreamExecutorClient::CreateUninitializedBuffer");
VLOG(2) << "PjRtStreamExecutorClient::CreateUninitializedBuffer: shape: "
<< shape.ToString() << " device: " << device->DebugString();
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->GetLocalDeviceState());
TransferManager* transfer_manager = client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(Shape compact_shape,
@@ -733,13 +745,16 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
definition_event);
}
StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme(
"PjRtStreamExecutorClient::BufferFromHostLiteral");
VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostLiteral: shape: "
<< literal.shape().ToString() << " device: " << device->DebugString();
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->GetLocalDeviceState());
TransferManager* transfer_manager = client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(
@@ -792,7 +807,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
return py_buffer;
}
void PjRtStreamExecutorClient::MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier) {
if (shapes.empty()) {
@@ -801,7 +816,9 @@ void PjRtClient::MakeCrossHostReceiveBuffers(
return;
}
auto local_device_or =
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->GetLocalDeviceState();
if (!local_device_or.ok()) {
notifier(local_device_or.status());
return;
@@ -828,27 +845,30 @@ void PjRtClient::MakeCrossHostReceiveBuffers(
}
// Transfer the given literal to the infeed queue of the given local device.
Status PjRtStreamExecutorDevice::TransferToInfeed(
const LiteralSlice& literal) const {
// Only support infeed to local device.
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
return local_device->client()->TransferToInfeedLocal(
literal, local_device->device_ordinal());
}
StatusOr<Literal> PjRtStreamExecutorDevice::TransferFromOutfeed(
const Shape& shape) const {
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
return local_device->client()->TransferFromOutfeedLocal(
shape, local_device->device_ordinal());
}
StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
int local_hardware_id) const {
for (auto* device : local_devices_) {
if (local_hardware_id == device->local_hardware_id()) {
return device;
}
}
return InvalidArgument("No matching device found for local_hardware_id %d",
local_hardware_id);
}
PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
@@ -873,7 +893,8 @@ PjRtBuffer::~PjRtBuffer() {
}
int64 PjRtBuffer::OnDeviceSizeInBytes() const {
return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
->client()
->backend()
.transfer_manager()
->GetByteSizeRequirement(on_device_shape_);
@@ -919,7 +940,9 @@ StatusOr<std::shared_ptr<TrackedDeviceBuffer>> PjRtBuffer::Release(
// the final set of usage events.
events = device_buffer->LockUseAndTransferUsageEvents();
}
LocalDeviceState* local_device_state =
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
->local_device_state();
if (wait_for_operations_to_complete) {
// Block the host until all usage events have completed. Usage events
// dominate definition events, so this also waits for the buffer to be
@@ -1080,7 +1103,9 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
}
ScopedHold device_buffer(this, ScopedHold::kUsage);
std::shared_ptr<HostValue> host_value;
LocalDeviceState* local_device =
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
->local_device_state();
se::Stream* stream = local_device->GetDeviceToHostStream();
const xla::Layout& host_layout =
layout.has_value() ? layout.value() : on_host_shape_.layout();
@@ -1122,12 +1147,16 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
host_value->value = std::make_shared<Literal>(host_shape);
ShapedBuffer shaped_buffer =
device_buffer->AsShapedBuffer(host_shape, on_device_shape_);
tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
->client()
->backend()
.transfer_manager()
->TransferLiteralFromDevice(stream, shaped_buffer,
host_value->value.get(),
[host_value](Status done_status) {
host_value->status = done_status;
host_value->ready.Notify();
});
auto usage_event = std::make_shared<BufferSequencingEvent>();
StatusOr<EventPool::Handle> event_or =
@@ -1156,7 +1185,7 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
const bool discard_cached_copy, absl::optional<xla::Layout> layout) {
tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::ToLiteral");
TF_ASSIGN_OR_RETURN(std::shared_ptr<HostValue> host_value,
CopyToHostAsyncInternal(discard_cached_copy, layout));
if (host_value == nullptr) {
@@ -1241,8 +1270,9 @@ PjRtBuffer::CopyToDeviceHelper(
// StallStreamOnError only makes sure the destination device is ok, so
// make sure that the src buffer remains valid until after any transfers
// have completed.
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
->local_device_state()
->ThenRelease(transfer_stream, src_device_buffer);
}
return copy_event_or.status();
}
@@ -1265,14 +1295,20 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral());
return dst_device->client()->BufferFromHostBuffer(
literal->untyped_data(), literal->shape(),
PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy, nullptr,
dst_device);
}
TF_ASSIGN_OR_RETURN(
LocalDeviceState * dst_local_device,
tensorflow::down_cast<PjRtStreamExecutorDevice*>(dst_device)
->GetLocalDeviceState());
LocalDeviceState* transfer_local_device =
tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
->EnqueueD2DTransfersOnSrcStream()
? tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
->local_device_state()
: dst_local_device;
CHECK_EQ(dst_local_device->allocation_model(),
transfer_local_device->allocation_model());
@@ -1310,7 +1346,9 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
// alternative is to ensure, before freeing the buffer, that the compute
// stream is synchronized past the transfer, but it seems better to hold onto
// the buffer too long than to stall the compute stream.
RecordUsage(std::move(src_device_buffer),
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
->local_device_state(),
transfer_local_device, event, transfer_stream,
/*prefer_to_retain_reference=*/true);
@@ -1318,7 +1356,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
}
Status PjRtBuffer::CopyToRemoteDevice(absl::string_view serialized_descriptor) {
return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
->CopyToRemoteDevice(this, serialized_descriptor);
}
Status PjRtBuffer::BlockHostUntilReady() {
@@ -1332,7 +1371,9 @@ Status PjRtBuffer::BlockHostUntilReady() {
}
device_buffer = device_buffer_;
}
LocalDeviceState* local_device_state =
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
->local_device_state();
std::unique_ptr<se::Stream> stream;
for (auto& event : device_buffer->definition_events()) {
if (!event->IsComplete()) {
@@ -1378,9 +1419,13 @@ StatusOr<TupleHandle> MakeTupleHelper(
Shape on_host_shape = ShapeUtil::MakeTupleShape(host_shapes);
Shape on_device_shape = ShapeUtil::MakeTupleShape(device_shapes);
se::DeviceMemoryAllocator* allocator =
tensorflow::down_cast<PjRtStreamExecutorClient*>(client)->allocator();
TransferManager* transfer_manager =
tensorflow::down_cast<PjRtStreamExecutorClient*>(client)
->client()
->backend()
.transfer_manager();
se::Stream* stream = local_device->host_to_device_stream();
TF_ASSIGN_OR_RETURN(
se::OwningDeviceMemory root_table_memory,
@@ -1444,14 +1489,6 @@ std::unique_ptr<PjRtBuffer> OutputBufferHelper(
/*prefer_to_retain_reference=*/false);
return pjrt_buffer;
}
static PjRtDevice* LookupDevice(const PjRtClient& client, int device_id) {
auto it = client.id_to_device().find(device_id);
CHECK(it != client.id_to_device().end())
<< "Unknown device id: " << device_id;
return it->second;
}
} // namespace
PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
@@ -1459,7 +1496,8 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
bool parameter_is_tupled_arguments,
std::shared_ptr<DeviceAssignment> device_assignment,
std::vector<LogicalDeviceIds> addressable_device_logical_ids,
std::vector<PjRtDevice*> addressable_devices,
PjRtStreamExecutorClient* client)
: client_(client),
device_assignment_(std::move(device_assignment)),
parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
@@ -1482,7 +1520,7 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
VLOG(1) << "PjRtStreamExecutorExecutable device_assignment:\n"
<< device_assignment_->ToString();
CHECK_GE(addressable_devices_.size(), 1) << device_assignment_->ToString();
CHECK_LE(addressable_devices_.size(), client_->addressable_device_count())
<< "Inconsistent local device count.";
num_partitions = device_assignment_->computation_count();
}
@@ -1584,7 +1622,7 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
absl::Span<const PjRtBuffer::ScopedHold> device_buffers,
absl::flat_hash_set<BufferSequencingEvent*>& events) const {
std::vector<ExecutionInput> execution_inputs;
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
// Lift tuple_handle outside the conditional so that the event it returns is
// not destroyed until after the loop below that waits on events.
absl::optional<TupleHandle> tuple_handle;
@@ -1607,8 +1645,10 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
execution_input.MutableBuffers()->begin();
ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
execution_input.MutableBuffers()->end();
device_buffers[i].AddToInput(
&input_iterator, iterator_end, &execution_input,
tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
->allocator());
CHECK(input_iterator == iterator_end);
}
}
@ -1628,8 +1668,10 @@ StatusOr<ScopedShapedBuffer> PjRtStreamExecutorExecutable::EnqueueExecution(
int executable_idx, const RunId& run_id, const ExecuteOptions& options, int executable_idx, const RunId& run_id, const ExecuteOptions& options,
PjRtDevice* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers, PjRtDevice* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers,
std::shared_ptr<DeviceAssignment> device_assignment) const { std::shared_ptr<DeviceAssignment> device_assignment) const {
int device_ordinal = device->local_device_state()->device_ordinal(); int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
LocalDeviceState* device_state = &client_->device_state(device_ordinal); ->local_device_state()
->device_ordinal();
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
tensorflow::profiler::TraceMeConsumer activity( tensorflow::profiler::TraceMeConsumer activity(
"LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt, "LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
run_id.ToInt()); run_id.ToInt());
@ -1765,7 +1807,7 @@ PjRtStreamExecutorExecutable::MakeOutputBuffers(
std::shared_ptr<BufferSequencingEvent> definition_event, std::shared_ptr<BufferSequencingEvent> definition_event,
PjRtDevice* device) const { PjRtDevice* device) const {
std::vector<std::unique_ptr<PjRtBuffer>> outputs; std::vector<std::unique_ptr<PjRtBuffer>> outputs;
LocalDeviceState* device_state = &client_->device_state(device_ordinal); LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) { if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) {
int tuple_count = result_buffer.on_host_shape().tuple_shapes_size(); int tuple_count = result_buffer.on_host_shape().tuple_shapes_size();
outputs.reserve(tuple_count); outputs.reserve(tuple_count);
@ -1802,7 +1844,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
if (device == nullptr) { if (device == nullptr) {
CHECK(device_assignment_ != nullptr); CHECK(device_assignment_ != nullptr);
const int device_id = (*device_assignment_)(replica, partition); const int device_id = (*device_assignment_)(replica, partition);
device = LookupDevice(*client_, device_id); TF_ASSIGN_OR_RETURN(device, client_->LookupDevice(device_id));
device_assignment = device_assignment_; device_assignment = device_assignment_;
} else { } else {
CHECK(device_assignment_ == nullptr); CHECK(device_assignment_ == nullptr);
@ -1814,7 +1856,9 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
} }
CHECK_EQ(device->host_id(), client_->host_id()); CHECK_EQ(device->host_id(), client_->host_id());
int device_ordinal = device->local_device_state()->device_ordinal(); int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->local_device_state()
->device_ordinal();
tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute"); tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute");
VLOG(3) << "Replica " << replica << ", partition " << partition VLOG(3) << "Replica " << replica << ", partition " << partition
<< " mapped to device ordinal for execution: " << device_ordinal; << " mapped to device ordinal for execution: " << device_ordinal;
@ -1836,7 +1880,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
ScopedShapedBuffer result_buffer = ScopedShapedBuffer result_buffer =
result_buffer_or_status.ConsumeValueOrDie(); result_buffer_or_status.ConsumeValueOrDie();
LocalDeviceState* device_state = &client_->device_state(device_ordinal); LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
se::Stream* stream = device_state->compute_stream(); se::Stream* stream = device_state->compute_stream();
StatusOr<EventPool::Handle> event_or = StatusOr<EventPool::Handle> event_or =
device_state->event_pool().ThenAllocateAndRecordEvent(stream); device_state->event_pool().ThenAllocateAndRecordEvent(stream);
@ -1922,7 +1966,9 @@ PjRtStreamExecutorExecutable::Execute(
const int replica = addressable_device_logical_ids_[i].replica; const int replica = addressable_device_logical_ids_[i].replica;
const int partition = addressable_device_logical_ids_[i].partition; const int partition = addressable_device_logical_ids_[i].partition;
PjRtDevice* device = addressable_devices_[i]; PjRtDevice* device = addressable_devices_[i];
const LocalDeviceState& device_state = *device->local_device_state(); const LocalDeviceState& device_state =
*tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->local_device_state();
device_state.execute_thread()->Schedule([&, replica, partition, i] { device_state.execute_thread()->Schedule([&, replica, partition, i] {
results[i] = ExecuteHelper(argument_handles[i], replica, partition, results[i] = ExecuteHelper(argument_handles[i], replica, partition,
run_id, options); run_id, options);
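The hunks above repeatedly recover the concrete stream-executor device from a generic PjRtDevice* before touching LocalDeviceState. A condensed sketch of that pattern, written as a hypothetical free helper that is not part of this change:

#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/core/platform/casts.h"

// Recovers the stream-executor-specific state behind a generic PjRtDevice.
// Only valid when `device` is known to be a PjRtStreamExecutorDevice, which
// is the invariant the executable code above relies on.
xla::LocalDeviceState* GetLocalDeviceState(xla::PjRtDevice* device) {
  return tensorflow::down_cast<xla::PjRtStreamExecutorDevice*>(device)
      ->local_device_state();
}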
@ -2131,9 +2177,9 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
} // namespace } // namespace
StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile( StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
const XlaComputation& computation, CompileOptions options) { const XlaComputation& computation, CompileOptions options) {
tensorflow::profiler::TraceMe traceme("PjRtClient::Compile"); tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");
ExecutableBuildOptions& build_options = options.executable_build_options; ExecutableBuildOptions& build_options = options.executable_build_options;
if (!build_options.device_allocator()) { if (!build_options.device_allocator()) {
@ -2153,14 +2199,15 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
num_partitions = 1; num_partitions = 1;
} else { } else {
if (!build_options.has_device_assignment()) { if (!build_options.has_device_assignment()) {
VLOG(2) << "PjRtClient::Compile using default device_assignment."; VLOG(2) << "PjRtStreamExecutorClient::Compile using default "
"device_assignment.";
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
DeviceAssignment device_assignment, DeviceAssignment device_assignment,
GetDefaultDeviceAssignment(build_options.num_replicas(), GetDefaultDeviceAssignment(build_options.num_replicas(),
build_options.num_partitions())); build_options.num_partitions()));
build_options.set_device_assignment(device_assignment); build_options.set_device_assignment(device_assignment);
} }
VLOG(2) << "PjRtClient::Compile device_assignment:\n" VLOG(2) << "PjRtStreamExecutorClient::Compile device_assignment:\n"
<< build_options.device_assignment().ToString(); << build_options.device_assignment().ToString();
num_replicas = build_options.device_assignment().replica_count(); num_replicas = build_options.device_assignment().replica_count();
num_partitions = build_options.device_assignment().computation_count(); num_partitions = build_options.device_assignment().computation_count();
@ -2234,7 +2281,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
for (int replica = 0; replica < num_replicas; ++replica) { for (int replica = 0; replica < num_replicas; ++replica) {
for (int partition = 0; partition < num_partitions; ++partition) { for (int partition = 0; partition < num_partitions; ++partition) {
int device_id = (*device_assignment)(replica, partition); int device_id = (*device_assignment)(replica, partition);
PjRtDevice* device = LookupDevice(*this, device_id); TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
if (device->host_id() != host_id()) { if (device->host_id() != host_id()) {
VLOG(3) << "Non-local device: " << device_id; VLOG(3) << "Non-local device: " << device_id;
continue; continue;
@ -2254,7 +2301,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
if (build_options.device_ordinal() < 0) { if (build_options.device_ordinal() < 0) {
build_options.set_device_ordinal( build_options.set_device_ordinal(
addressable_devices.front()->local_device_state()->device_ordinal()); addressable_devices.front()->local_hardware_id());
} }
} }
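The compile path above now resolves every id in the device assignment through the client and skips devices owned by other hosts. A reduced sketch of that filtering, assuming an existing client and a populated DeviceAssignment; the helper name is illustrative only:

#include <vector>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/statusor.h"

// Collects the devices in `assignment` that this client can issue commands
// to, skipping entries that live on other hosts.
xla::StatusOr<std::vector<xla::PjRtDevice*>> LocalDevicesInAssignment(
    const xla::PjRtClient& client, const xla::DeviceAssignment& assignment) {
  std::vector<xla::PjRtDevice*> result;
  for (int replica = 0; replica < assignment.replica_count(); ++replica) {
    for (int partition = 0; partition < assignment.computation_count();
         ++partition) {
      auto device_or = client.LookupDevice(assignment(replica, partition));
      if (!device_or.ok()) return device_or.status();
      xla::PjRtDevice* device = device_or.ValueOrDie();
      if (device->host_id() != client.host_id()) continue;  // remote device
      result.push_back(device);
    }
  }
  return result;
}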


@ -45,6 +45,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
@ -67,16 +68,56 @@ class PjRtClient;
class PjRtDevice { class PjRtDevice {
public: public:
explicit PjRtDevice(int id, virtual ~PjRtDevice() {}
std::unique_ptr<LocalDeviceState> local_device_state,
std::string device_kind, int host_id = 0) // Return the client that owns this device.
virtual PjRtClient* client() const = 0;
// Whether client can issue command to this device.
virtual bool IsAddressable() const = 0;
// The ID of this device. IDs are unique among devices of this type
// (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
// hosts' devices. This is the ID that should be used in a DeviceAssignment.
virtual int id() const = 0;
// The task ID of this device according to TpuTopology. This is not the same
// as PjRtClient::host_id() in a multi-task setting, where each client can see
// devices from all tasks, but only a subset of them are addressable and have
// the same task_id as the client.
virtual int host_id() const = 0;
// Opaque hardware ID, e.g., the CUDA device number, useful for identifying
// which GPU when interacting with non-JAX code. In general, not guaranteed to
// be dense, and -1 if undefined.
virtual int local_hardware_id() const = 0;
// A vendor-dependent string that uniquely identifies the kind of device,
// e.g., "Tesla V100-SXM2-16GB". May be used to determine whether two GPUs are
// compatible compilation targets.
virtual const std::string& device_kind() const = 0;
virtual std::string DebugString() const = 0;
// Transfer the given literal to the infeed queue.
virtual Status TransferToInfeed(const LiteralSlice& literal) const = 0;
// Transfer and return a value of the given shape from the outfeed queue.
virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const = 0;
};
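With PjRtDevice reduced to the pure-virtual interface above, backend-agnostic callers can inspect devices without knowing the concrete type. A minimal sketch, assuming an already-constructed client (how the client is created is backend-specific and outside this diff):

#include <iostream>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

// Prints a one-line summary for every device the client knows about, using
// only the virtual PjRtDevice methods declared above.
void DumpDevices(const xla::PjRtClient& client) {
  for (xla::PjRtDevice* device : client.devices()) {
    std::cout << device->DebugString()
              << " id=" << device->id()
              << " host=" << device->host_id()
              << " kind=" << device->device_kind()
              << " local_hw_id=" << device->local_hardware_id()
              << (device->IsAddressable() ? " (addressable)" : " (remote)")
              << std::endl;
  }
}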
class PjRtStreamExecutorDevice : public PjRtDevice {
public:
explicit PjRtStreamExecutorDevice(
int id, std::unique_ptr<LocalDeviceState> local_device_state,
std::string device_kind, int host_id = 0)
: id_(id), : id_(id),
local_device_id_( device_ordinal_(
local_device_state ? local_device_state->device_ordinal() : -1), local_device_state ? local_device_state->device_ordinal() : -1),
local_device_state_(std::move(local_device_state)), local_device_state_(std::move(local_device_state)),
host_id_(host_id), host_id_(host_id),
device_kind_(std::move(device_kind)) {} device_kind_(std::move(device_kind)) {}
virtual ~PjRtDevice() {} ~PjRtStreamExecutorDevice() override {}
// Must set client exactly once. // Must set client exactly once.
void SetClient(PjRtClient* client) { void SetClient(PjRtClient* client) {
@ -84,14 +125,25 @@ class PjRtDevice {
client_ = client; client_ = client;
} }
// Task ID. This is always 0 on single-task setup.
int host_id() const override { return host_id_; }
// Return `platform_id` from client.
PjRtPlatformId platform_id() const;
// Return `platform_name` from client.
const std::string& platform_name() const;
PjRtClient* client() const override { return client_; }
// The ID of this device. IDs are unique among devices of this type // The ID of this device. IDs are unique among devices of this type
// (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
// hosts' devices. This is the ID that should be used in a DeviceAssignment. // hosts' devices. This is the ID that should be used in a DeviceAssignment.
int id() const { return id_; } int id() const override { return id_; }
bool IsLocalDevice() const { return local_device_id_ != -1; } bool IsAddressable() const override { return device_ordinal_ != -1; }
int local_device_id() const { return local_device_id_; } int local_hardware_id() const override { return device_ordinal_; }
// If this is a device local to this host, returns a LocalDeviceState object // If this is a device local to this host, returns a LocalDeviceState object
// that can be used to manipulate the device. Returns nullptr if the device is // that can be used to manipulate the device. Returns nullptr if the device is
@ -105,32 +157,21 @@ class PjRtDevice {
// is not local to this host. // is not local to this host.
StatusOr<LocalDeviceState*> GetLocalDeviceState() const; StatusOr<LocalDeviceState*> GetLocalDeviceState() const;
// The ID of this device's host. This is always 0 on single-host platforms.
int host_id() const { return host_id_; }
// Return `platform_id` from client.
PjRtPlatformId platform_id() const;
// Return `platform_name` from client.
const std::string& platform_name() const;
// A vendor-dependent string that uniquely identifies the kind of device. // A vendor-dependent string that uniquely identifies the kind of device.
const std::string& device_kind() const { return device_kind_; } const std::string& device_kind() const override { return device_kind_; }
virtual std::string DebugString() const; std::string DebugString() const override;
PjRtClient* client() const { return client_; }
// Transfer the given literal to the infeed queue of the given local device. // Transfer the given literal to the infeed queue of the given local device.
virtual Status TransferToInfeed(const LiteralSlice& literal) const; Status TransferToInfeed(const LiteralSlice& literal) const override;
// Transfer and return a value of the given shape from the outfeed of the // Transfer and return a value of the given shape from the outfeed of the
// given device. // given device.
virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const; StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override;
private: private:
const int id_; const int id_;
const int local_device_id_; // -1 means not local. const int device_ordinal_; // -1 means not local.
const std::unique_ptr<LocalDeviceState> local_device_state_; const std::unique_ptr<LocalDeviceState> local_device_state_;
const int host_id_; const int host_id_;
const std::string device_kind_; const std::string device_kind_;
@ -178,86 +219,62 @@ class PjRtExecutable;
// alive as long as any of the other runtime objects are alive. // alive as long as any of the other runtime objects are alive.
class PjRtClient { class PjRtClient {
public: public:
// `allocator` may be null, in which case the platform default allocator is used.
explicit PjRtClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
virtual ~PjRtClient() = default; virtual ~PjRtClient() = default;
// TODO(zhangqiaorjc): Rename to task_id.
// Return the task id of this client. In single-task setting, always 0.
virtual int host_id() const = 0;
// Return the number of devices in the entire computation. In multi-headed
// client setting, some are addressable by this client, some are not. In a
// single-client setting, this is equal to the number of addressable devices.
virtual int device_count() const = 0;
// Return number of addressable devices. Addressable devices are those that
// the client can issue commands to.
virtual int addressable_device_count() const = 0;
// Return all devices in the entire computation, including addressable and
// non-addressable devices.
virtual absl::Span<PjRtDevice* const> devices() const = 0;
// TODO(zhangqiaorjc): Rename to addressable_devices.
// Return only addressable devices.
virtual absl::Span<PjRtDevice* const> local_devices() const = 0;
// Lookup any PjRtDevice for a given PjRtDevice::id().
virtual StatusOr<PjRtDevice*> LookupDevice(int device_id) const = 0;
// Return an addressable PjRtDevice for a given
// PjRtDevice::local_hardware_id().
virtual StatusOr<PjRtDevice*> LookupAddressableDevice(
int local_hardware_id) const = 0;
// Return an ID that identifies the platform (CPU/GPU/TPU).
virtual PjRtPlatformId platform_id() const = 0;
// Returns a string that identifies the platform (CPU/GPU/TPU).
virtual const std::string& platform_name() const = 0;
// Return a device-specific default device assignment, e.g., GPU and TPU may
// be different.
virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment( virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const; int num_replicas, int num_partitions) const = 0;
int device_count() const { return devices_.size(); }
int local_device_count() const { return local_devices_.size(); }
const std::vector<std::unique_ptr<PjRtDevice>>& devices() const {
return devices_;
}
const std::vector<PjRtDevice*>& local_devices() const {
return local_devices_;
}
const std::map<int, PjRtDevice*>& id_to_device() const {
return id_to_device_;
}
int host_id() const { return host_id_; }
PjRtPlatformId platform_id() const { return platform_id_; }
const std::string& platform_name() const { return platform_name_; }
LocalDeviceState& device_state(int device_ordinal) const {
return *local_devices_.at(device_ordinal)->local_device_state();
}
// Return a local PjRtDevice for a given `local_device_id`.
virtual StatusOr<PjRtDevice*> LookupLocalDevice(int local_device_id) const;
LocalClient* client() const { return client_; }
se::DeviceMemoryAllocator* allocator() const { return allocator_; }
tensorflow::Allocator* host_memory_allocator() const {
return host_memory_allocator_.get();
}
bool should_stage_host_to_device_transfers() const {
return should_stage_host_to_device_transfers_;
}
gpu::GpuExecutableRunOptions* gpu_run_options() const {
return gpu_run_options_.get();
}
tensorflow::thread::ThreadPool* h2d_transfer_pool() {
return &h2d_transfer_pool_;
}
// Most platforms expect device-to-device transfers to be enqueued on the
// source d2d stream, but some platforms use the destination d2d stream. This
// function specifies which one the platform expects.
virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
// Generates a unique fingerprint for `executable`.
virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
const PjRtExecutable& executable) const {
return absl::optional<std::string>();
}
// Returns a backend-specific HLO cost analysis visitor. // Returns a backend-specific HLO cost analysis visitor.
virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis(); virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() = 0;
// Compile `computation` with given `options`.
virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile( virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
const XlaComputation& computation, CompileOptions options); const XlaComputation& computation, CompileOptions options) = 0;
// Generates a unique fingerprint for `executable`, may be absl::nullopt.
virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
const PjRtExecutable& executable) const = 0;
// Creates a buffer on the device without initializing or copying any data. // Creates a buffer on the device without initializing or copying any data.
// An optional `definition_event` may be specified that can be used to
// ensure the buffer isn't referenced until some external mechanism has
// initialized the data.
// NOTE: The sequencing mechanism is not guaranteed to be supported by all
// future backends and so callers should avoid wherever possible.
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer( virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device); const Shape& shape, PjRtDevice* device) = 0;
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device,
std::shared_ptr<BufferSequencingEvent> definition_event);
// Describes the semantics the caller of BufferFromHostBuffer expects from the // Describes the semantics the caller of BufferFromHostBuffer expects from the
// runtime, in a total order from most restrictive to least restrictive. // runtime, in a total order from most restrictive to least restrictive.
@ -289,13 +306,13 @@ class PjRtClient {
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer( virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
const void* data, const Shape& shape, const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics, HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device); std::shared_ptr<void> buffer_reference, PjRtDevice* device) = 0;
// Note that literal must remain in scope until the transfer has completed, so // Note that literal must remain in scope until the transfer has completed, so
// the caller should, for example, wait for BlockHostUntilReady() completes on // the caller should, for example, wait for BlockHostUntilReady() completes on
// the return value before letting literal go out of scope. // the return value before letting literal go out of scope.
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral( virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
const LiteralSlice& literal, PjRtDevice* device); const LiteralSlice& literal, PjRtDevice* device) = 0;
// Asynchronously makes a vector of PjRtBuffers that can be used to receive // Asynchronously makes a vector of PjRtBuffers that can be used to receive
// cross host transfers using `client` on `device'. `shapes` must be the exact // cross host transfers using `client` on `device'. `shapes` must be the exact
@ -308,18 +325,140 @@ class PjRtClient {
// buffers will become ready until *all* of the sends have completed. // buffers will become ready until *all* of the sends have completed.
virtual void MakeCrossHostReceiveBuffers( virtual void MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, PjRtDevice* device, absl::Span<const Shape> shapes, PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier); PjRtCrossHostRecvNotifier&& notifier) = 0;
virtual StatusOr<ChannelHandle> CreateChannelHandle() { // Create ChannelHandles for XLA send/recv.
virtual StatusOr<ChannelHandle> CreateChannelHandle() = 0;
virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() = 0;
virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() = 0;
};
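The abstract PjRtClient above routes device lookup and compilation through Status-returning virtual methods. A rough usage sketch under those signatures; the helper and its minimal error handling are illustrative, not part of this change:

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/statusor.h"

// Compiles `computation` to run on the device with the given global id.
xla::StatusOr<std::unique_ptr<xla::PjRtExecutable>> CompileForDevice(
    xla::PjRtClient* client, const xla::XlaComputation& computation,
    int device_id) {
  // LookupDevice now returns StatusOr<PjRtDevice*> instead of going through
  // a free function.
  auto device_or = client->LookupDevice(device_id);
  if (!device_or.ok()) return device_or.status();
  xla::PjRtDevice* device = device_or.ValueOrDie();

  xla::CompileOptions options;
  // The build options take the opaque local hardware id, mirroring the
  // Compile() change earlier in this diff.
  options.executable_build_options.set_device_ordinal(
      device->local_hardware_id());
  return client->Compile(computation, std::move(options));
}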
class PjRtStreamExecutorClient : public PjRtClient {
public:
// `allocator` may be null, in which case the platform default allocator is used.
explicit PjRtStreamExecutorClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
int host_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
~PjRtStreamExecutorClient() override = default;
int host_id() const override { return host_id_; }
int device_count() const override { return devices_.size(); }
int addressable_device_count() const override {
return local_devices_.size();
}
absl::Span<PjRtDevice* const> devices() const override { return devices_; }
absl::Span<PjRtDevice* const> local_devices() const override {
return local_devices_;
}
StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
auto it = id_to_device_.find(device_id);
if (it != id_to_device_.end()) {
return it->second;
}
return InvalidArgument("No matching device found for device_id %d",
device_id);
}
StatusOr<PjRtDevice*> LookupAddressableDevice(
int local_hardware_id) const override;
PjRtPlatformId platform_id() const override { return platform_id_; }
const std::string& platform_name() const override { return platform_name_; }
// Most platforms expect device-to-device transfers to be enqueued on the
// source d2d stream, but some platforms use the destination d2d stream. This
// function specifies which one the platform expects.
virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override;
StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
const XlaComputation& computation, CompileOptions options) override;
// Generates a unique fingerprint for `executable`.
StatusOr<absl::optional<std::string>> ExecutableFingerprint(
const PjRtExecutable& executable) const override {
return absl::optional<std::string>();
}
// Returns a backend-specific HLO cost analysis visitor.
std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() override;
// Creates a buffer on the device without initializing or copying any data.
// An optional `definition_event` may be specified that can be used to
// ensure the buffer isn't referenced until some external mechanism has
// initialized the data.
// NOTE: The sequencing mechanism is not guaranteed to be supported by all
// future backends and so callers should avoid wherever possible.
StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device) override;
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device,
std::shared_ptr<BufferSequencingEvent> definition_event);
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device) override;
// Note that literal must remain in scope until the transfer has completed, so
// the caller should, for example, wait for BlockHostUntilReady() completes on
// the return value before letting literal go out of scope.
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
const LiteralSlice& literal, PjRtDevice* device) override;
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
// cross host transfers using `client` on `device'. `shapes` must be the exact
// shapes, with identical layouts, corresponding to the buffers that will be
// sent. When resources for the transfer are available, notifier will be
// called with a vector of PjRtCrossHostRecvBuffer structs, one for each
// shape in `shapes`. Each struct contains a buffer that will contain the
// received value, and an opaque string that should be transmitted to the
// sending host and used in a call to CopyToRemoteDevice. None of the recv
// buffers will become ready until *all* of the sends have completed.
void MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier) override;
StatusOr<ChannelHandle> CreateChannelHandle() override {
return client()->CreateChannelHandle(); return client()->CreateChannelHandle();
} }
virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() { StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
return client()->CreateDeviceToHostChannelHandle(); return client()->CreateDeviceToHostChannelHandle();
} }
virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() { StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
return client()->CreateHostToDeviceChannelHandle(); return client()->CreateHostToDeviceChannelHandle();
} }
LocalDeviceState& device_state(int device_ordinal) const {
return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
local_devices_.at(device_ordinal))
->local_device_state();
}
LocalClient* client() const { return client_; }
se::DeviceMemoryAllocator* allocator() const { return allocator_; }
tensorflow::Allocator* host_memory_allocator() const {
return host_memory_allocator_.get();
}
bool should_stage_host_to_device_transfers() const {
return should_stage_host_to_device_transfers_;
}
gpu::GpuExecutableRunOptions* gpu_run_options() const {
return gpu_run_options_.get();
}
tensorflow::thread::ThreadPool* h2d_transfer_pool() {
return &h2d_transfer_pool_;
}
protected: protected:
friend class PjRtBuffer; friend class PjRtBuffer;
virtual void EnqueueCrossHostReceive( virtual void EnqueueCrossHostReceive(
@ -342,7 +481,9 @@ class PjRtClient {
std::unique_ptr<tensorflow::Allocator> host_memory_allocator_; std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
// Includes all devices, including non-local devices on multi-host platforms. // Includes all devices, including non-local devices on multi-host platforms.
std::vector<std::unique_ptr<PjRtDevice>> devices_; std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_;
// Pointers to `owned_devices_`.
std::vector<PjRtDevice*> devices_;
// Maps Device::id() to the corresponding Device. Includes all devices. // Maps Device::id() to the corresponding Device. Includes all devices.
std::map<int, PjRtDevice*> id_to_device_; std::map<int, PjRtDevice*> id_to_device_;
// Local devices indexed by local device ordinal. // Local devices indexed by local device ordinal.
@ -509,7 +650,7 @@ class PjRtBuffer {
private: private:
friend class PjRtBuffer; friend class PjRtBuffer;
friend class PjRtClient; friend class PjRtStreamExecutorClient;
// Helper struct that makes it possible to move a ScopedHold through a // Helper struct that makes it possible to move a ScopedHold through a
// closure. // closure.
@ -769,7 +910,7 @@ class PjRtExecutable {
virtual PjRtClient* client() const = 0; virtual PjRtClient* client() const = 0;
// Unique name for this executable, e.g., HloModule name. // Unique name for this executable, e.g., HloModule name.
virtual const string& name() const = 0; virtual const std::string& name() const = 0;
virtual int num_replicas() const = 0; virtual int num_replicas() const = 0;
@ -791,6 +932,7 @@ class PjRtExecutable {
virtual absl::Span<const LogicalDeviceIds> addressable_device_logical_ids() virtual absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
const = 0; const = 0;
// An addressable_device is one which the client can issue commands to.
// addressable_devices()[i] is the Device to which // addressable_devices()[i] is the Device to which
// addressable_device_logical_ids()[i] is assigned. // addressable_device_logical_ids()[i] is assigned.
virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0; virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0;
@ -833,13 +975,14 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
bool parameter_is_tupled_arguments, bool parameter_is_tupled_arguments,
std::shared_ptr<DeviceAssignment> device_assignment, std::shared_ptr<DeviceAssignment> device_assignment,
std::vector<LogicalDeviceIds> addressable_device_logical_ids, std::vector<LogicalDeviceIds> addressable_device_logical_ids,
std::vector<PjRtDevice*> addressable_devices, PjRtClient* client); std::vector<PjRtDevice*> addressable_devices,
PjRtStreamExecutorClient* client);
~PjRtStreamExecutorExecutable() override = default; ~PjRtStreamExecutorExecutable() override = default;
PjRtClient* client() const override { return client_; } PjRtStreamExecutorClient* client() const override { return client_; }
const string& name() const override; const std::string& name() const override;
int num_replicas() const override { int num_replicas() const override {
return executables_[0]->build_options().num_replicas(); return executables_[0]->build_options().num_replicas();
@ -898,7 +1041,7 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
} }
private: private:
friend class PjRtClient; friend class PjRtStreamExecutorClient;
// Initializes information about which arguments to which executables must be // Initializes information about which arguments to which executables must be
// donated due to aliases that were specified by the computation. // donated due to aliases that were specified by the computation.
Status SetUpDonation(bool tuple_inputs); Status SetUpDonation(bool tuple_inputs);
@ -933,7 +1076,7 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
// Create shared pointers so we can free them after the execution: with // Create shared pointers so we can free them after the execution: with
// asynchronous execution, the process being executed can outlive the // asynchronous execution, the process being executed can outlive the
// executable itself. // executable itself.
PjRtClient* const client_; PjRtStreamExecutorClient* const client_;
// One executable per partition. // One executable per partition.
std::vector<std::shared_ptr<LocalExecutable>> executables_; std::vector<std::shared_ptr<LocalExecutable>> executables_;
// Per-executable set of parameters that have any aliased buffers and thus // Per-executable set of parameters that have any aliased buffers and thus


@ -94,10 +94,11 @@ Status TpuDeviceState::ThenMemcpyDeviceToDevice(
return Status::OK(); return Status::OK();
} }
class PjRtTpuClient : public PjRtClient { class PjRtTpuClient : public PjRtStreamExecutorClient {
public: public:
PjRtTpuClient(LocalClient* client, PjRtTpuClient(LocalClient* client,
std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id); std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
int host_id);
StatusOr<DeviceAssignment> GetDefaultDeviceAssignment( StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override; int num_replicas, int num_partitions) const override;
@ -108,14 +109,14 @@ class PjRtTpuClient : public PjRtClient {
const PjRtExecutable& executable) const override; const PjRtExecutable& executable) const override;
}; };
PjRtTpuClient::PjRtTpuClient(LocalClient* client, PjRtTpuClient::PjRtTpuClient(
std::vector<std::unique_ptr<PjRtDevice>> devices, LocalClient* client,
int host_id) std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id)
: PjRtClient(kTpuName, client, std::move(devices), host_id, : PjRtStreamExecutorClient(kTpuName, client, std::move(devices), host_id,
/*allocator=*/nullptr, /*allocator=*/nullptr,
/*host_memory_allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false, /*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr) {} /*gpu_run_options=*/nullptr) {}
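PjRtTpuClient now only supplies the TPU-specific pieces and forwards everything else to PjRtStreamExecutorClient. A hypothetical non-TPU backend would be wired up the same way; a minimal sketch with illustrative names:

#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

// A hypothetical client for some custom stream-executor backend, constructed
// the same way as PjRtTpuClient above: default allocators, no host staging,
// no GPU run options.
class MyBackendClient : public xla::PjRtStreamExecutorClient {
 public:
  MyBackendClient(
      xla::LocalClient* client,
      std::vector<std::unique_ptr<xla::PjRtStreamExecutorDevice>> devices,
      int host_id)
      : xla::PjRtStreamExecutorClient(
            "my-backend", client, std::move(devices), host_id,
            /*allocator=*/nullptr,
            /*host_memory_allocator=*/nullptr,
            /*should_stage_host_to_device_transfers=*/false,
            /*gpu_run_options=*/nullptr) {}
};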
StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment( StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const { int num_replicas, int num_partitions) const {
@ -128,7 +129,8 @@ StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
num_partitions); num_partitions);
} }
// Fallback to default global device assignment if we can't run locally. // Fallback to default global device assignment if we can't run locally.
return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions); return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
num_partitions);
} }
StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint( StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
@ -152,10 +154,10 @@ StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
return absl::optional<std::string>(tpu_executable->fingerprint()); return absl::optional<std::string>(tpu_executable->fingerprint());
} }
StatusOr<std::vector<std::unique_ptr<PjRtDevice>>> GetTpuDevices( StatusOr<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>> GetTpuDevices(
LocalClient* client, LocalClient* client,
std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) { std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
std::vector<std::unique_ptr<PjRtDevice>> devices; std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
tf_tpu::TpuTopologyExternal topology = tf_tpu::TpuTopologyExternal topology =
tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology(); tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();


@ -26,14 +26,14 @@ limitations under the License.
namespace xla { namespace xla {
class PjRtTpuDevice : public PjRtDevice { class PjRtTpuDevice : public PjRtStreamExecutorDevice {
public: public:
PjRtTpuDevice(const tensorflow::tpu::TpuCoreLocationExternal core, PjRtTpuDevice(const tensorflow::tpu::TpuCoreLocationExternal core,
std::unique_ptr<LocalDeviceState> local_device_state, std::unique_ptr<LocalDeviceState> local_device_state,
int host_id, const std::array<int, 3>& coords, int host_id, const std::array<int, 3>& coords,
std::string device_kind) std::string device_kind)
: PjRtDevice(core.Id(), std::move(local_device_state), : PjRtStreamExecutorDevice(core.Id(), std::move(local_device_state),
std::move(device_kind), host_id), std::move(device_kind), host_id),
core_(core), core_(core),
coords_(coords) {} coords_(coords) {}
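Likewise, a backend-specific device now derives from PjRtStreamExecutorDevice and simply forwards to its constructor, adding only its own state. A minimal sketch with hypothetical names:

#include <memory>
#include <string>
#include <utility>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

// A hypothetical device for some custom backend. All PjRtDevice virtual
// methods are inherited from PjRtStreamExecutorDevice; only backend-specific
// state is added here.
class MyBackendDevice : public xla::PjRtStreamExecutorDevice {
 public:
  MyBackendDevice(int id,
                  std::unique_ptr<xla::LocalDeviceState> local_device_state,
                  int host_id, std::string chip_serial)
      : xla::PjRtStreamExecutorDevice(id, std::move(local_device_state),
                                      /*device_kind=*/"my-backend", host_id),
        chip_serial_(std::move(chip_serial)) {}

  const std::string& chip_serial() const { return chip_serial_; }

 private:
  const std::string chip_serial_;
};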


@ -97,13 +97,13 @@ cc_library(
name = "types", name = "types",
srcs = ["types.cc"], srcs = ["types.cc"],
hdrs = ["types.h"], hdrs = ["types.h"],
compatible_with = [],
copts = [ copts = [
"-fexceptions", "-fexceptions",
"-fno-strict-aliasing", "-fno-strict-aliasing",
], ],
features = ["-use_header_modules"], features = ["-use_header_modules"],
deps = [ deps = [
":bfloat16",
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status",
@ -113,6 +113,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/pjrt:pjrt_client", "//tensorflow/compiler/xla/pjrt:pjrt_client",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/python:bfloat16_lib",
"//third_party/py/numpy:headers", "//third_party/py/numpy:headers",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:inlined_vector",
@ -158,42 +159,6 @@ cc_library(
], ],
) )
cc_library(
name = "bfloat16",
srcs = ["bfloat16.cc"],
hdrs = ["bfloat16.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core/platform:bfloat16",
"//tensorflow/core/platform:logging",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers", # buildcleaner: keep
"@com_google_absl//absl/strings",
"@pybind11",
],
)
py_test(
name = "bfloat16_test",
srcs = ["bfloat16_test.py"],
main = "bfloat16_test.py",
python_version = "PY3",
tags = ["no_oss"],
deps = [
":xla_client",
":xla_extension",
"@absl_py//absl/testing:absltest",
"@absl_py//absl/testing:parameterized",
] + xla_py_test_deps(),
)
cc_library( cc_library(
name = "py_client", name = "py_client",
srcs = [ srcs = [
@ -206,6 +171,7 @@ cc_library(
"py_client.h", "py_client.h",
"py_executable.h", "py_executable.h",
], ],
compatible_with = [],
copts = [ copts = [
"-fexceptions", "-fexceptions",
"-fno-strict-aliasing", "-fno-strict-aliasing",
@ -232,6 +198,7 @@ cc_library(
name = "dlpack", name = "dlpack",
srcs = ["dlpack.cc"], srcs = ["dlpack.cc"],
hdrs = ["dlpack.h"], hdrs = ["dlpack.h"],
compatible_with = [],
copts = [ copts = [
"-fexceptions", "-fexceptions",
"-fno-strict-aliasing", "-fno-strict-aliasing",
@ -263,6 +230,7 @@ cc_library(
name = "jax_jit", name = "jax_jit",
srcs = ["jax_jit.cc"], srcs = ["jax_jit.cc"],
hdrs = ["jax_jit.h"], hdrs = ["jax_jit.h"],
compatible_with = [],
copts = [ copts = [
"-fexceptions", "-fexceptions",
"-fno-strict-aliasing", "-fno-strict-aliasing",
@ -292,6 +260,7 @@ cc_library(
name = "ops", name = "ops",
srcs = ["ops.cc"], srcs = ["ops.cc"],
hdrs = ["ops.h"], hdrs = ["ops.h"],
compatible_with = [],
copts = [ copts = [
"-fexceptions", "-fexceptions",
"-fno-strict-aliasing", "-fno-strict-aliasing",
@ -356,6 +325,7 @@ cc_library(
name = "outfeed_receiver_py", name = "outfeed_receiver_py",
srcs = ["outfeed_receiver_py.cc"], srcs = ["outfeed_receiver_py.cc"],
hdrs = ["outfeed_receiver_py.h"], hdrs = ["outfeed_receiver_py.h"],
compatible_with = [],
copts = [ copts = [
"-fexceptions", "-fexceptions",
"-fno-strict-aliasing", "-fno-strict-aliasing",
@ -379,12 +349,14 @@ cc_library(
name = "pytree", name = "pytree",
srcs = ["pytree.cc"], srcs = ["pytree.cc"],
hdrs = ["pytree.h"], hdrs = ["pytree.h"],
compatible_with = [],
copts = [ copts = [
"-fexceptions", "-fexceptions",
"-fno-strict-aliasing", "-fno-strict-aliasing",
], ],
features = ["-use_header_modules"], features = ["-use_header_modules"],
deps = [ deps = [
":types",
"@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/hash", "@com_google_absl//absl/hash",
@ -435,6 +407,7 @@ cc_library(
name = "xla_compiler", name = "xla_compiler",
srcs = ["xla_compiler.cc"], srcs = ["xla_compiler.cc"],
hdrs = ["xla_compiler.h"], hdrs = ["xla_compiler.h"],
compatible_with = [],
copts = [ copts = [
"-fexceptions", "-fexceptions",
"-fno-strict-aliasing", "-fno-strict-aliasing",
@ -481,7 +454,6 @@ pybind_extension(
features = ["-use_header_modules"], features = ["-use_header_modules"],
module_name = "xla_extension", module_name = "xla_extension",
deps = [ deps = [
":bfloat16",
":dlpack", ":dlpack",
":jax_jit", ":jax_jit",
":ops", ":ops",
@ -534,6 +506,7 @@ pybind_extension(
# without any TF dependencies as "jaxlib" on Pypi, and "jaxlib" does # without any TF dependencies as "jaxlib" on Pypi, and "jaxlib" does
# not require Tensorflow. # not require Tensorflow.
"//tensorflow/core:lib_internal_impl", # buildcleaner: keep "//tensorflow/core:lib_internal_impl", # buildcleaner: keep
"//tensorflow/python:bfloat16_lib",
"//tensorflow/stream_executor:device_memory_allocator", "//tensorflow/stream_executor:device_memory_allocator",
"//tensorflow/stream_executor:platform", "//tensorflow/stream_executor:platform",
] + select({ ] + select({

File diff suppressed because it is too large

@ -1,440 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for the bfloat16 Python type."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import itertools
import math
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from tensorflow.compiler.xla.python import xla_client
bfloat16 = xla_client.bfloat16
def numpy_assert_allclose(a, b, **kwargs):
a = a.astype(np.float32) if a.dtype == bfloat16 else a
b = b.astype(np.float32) if b.dtype == bfloat16 else b
return np.testing.assert_allclose(a, b, **kwargs)
epsilon = float.fromhex("1.0p-7")
# Values that should round trip exactly to float and back.
FLOAT_VALUES = [
0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
-1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
float("inf"),
float("-inf"),
float("nan")
]
class Bfloat16Test(parameterized.TestCase):
"""Tests the non-numpy Python methods of the bfloat16 type."""
def testRoundTripToFloat(self):
for v in FLOAT_VALUES:
np.testing.assert_equal(v, float(bfloat16(v)))
def testRoundTripNumpyTypes(self):
for dtype in [np.float16, np.float32, np.float64]:
np.testing.assert_equal(-3.75, dtype(bfloat16(dtype(-3.75))))
np.testing.assert_equal(1.5, float(bfloat16(dtype(1.5))))
np.testing.assert_equal(4.5, dtype(bfloat16(np.array(4.5, dtype))))
np.testing.assert_equal(
np.array([2, 5, -1], bfloat16), bfloat16(np.array([2, 5, -1], dtype)))
def testRoundTripToInt(self):
for v in [-256, -255, -34, -2, -1, 0, 1, 2, 10, 47, 128, 255, 256, 512]:
self.assertEqual(v, int(bfloat16(v)))
# pylint: disable=g-complex-comprehension
@parameterized.named_parameters(({
"testcase_name": "_" + dtype.__name__,
"dtype": dtype
} for dtype in [bfloat16, np.float16, np.float32, np.float64]))
def testRoundTripToNumpy(self, dtype):
for v in FLOAT_VALUES:
np.testing.assert_equal(v, bfloat16(dtype(v)))
np.testing.assert_equal(v, dtype(bfloat16(dtype(v))))
np.testing.assert_equal(v, dtype(bfloat16(np.array(v, dtype))))
if dtype != bfloat16:
np.testing.assert_equal(
np.array(FLOAT_VALUES, dtype),
bfloat16(np.array(FLOAT_VALUES, dtype)).astype(dtype))
def testStr(self):
self.assertEqual("0", str(bfloat16(0.0)))
self.assertEqual("1", str(bfloat16(1.0)))
self.assertEqual("-3.5", str(bfloat16(-3.5)))
self.assertEqual("0.0078125", str(bfloat16(float.fromhex("1.0p-7"))))
self.assertEqual("inf", str(bfloat16(float("inf"))))
self.assertEqual("-inf", str(bfloat16(float("-inf"))))
self.assertEqual("nan", str(bfloat16(float("nan"))))
def testRepr(self):
self.assertEqual("0", repr(bfloat16(0)))
self.assertEqual("1", repr(bfloat16(1)))
self.assertEqual("-3.5", repr(bfloat16(-3.5)))
self.assertEqual("0.0078125", repr(bfloat16(float.fromhex("1.0p-7"))))
self.assertEqual("inf", repr(bfloat16(float("inf"))))
self.assertEqual("-inf", repr(bfloat16(float("-inf"))))
self.assertEqual("nan", repr(bfloat16(float("nan"))))
def testHash(self):
self.assertEqual(0, hash(bfloat16(0.0)))
self.assertEqual(0x3f80, hash(bfloat16(1.0)))
self.assertEqual(0x7fc0, hash(bfloat16(float("nan"))))
# Tests for Python operations
def testNegate(self):
for v in FLOAT_VALUES:
np.testing.assert_equal(-v, float(-bfloat16(v)))
def testAdd(self):
np.testing.assert_equal(0, float(bfloat16(0) + bfloat16(0)))
np.testing.assert_equal(1, float(bfloat16(1) + bfloat16(0)))
np.testing.assert_equal(0, float(bfloat16(1) + bfloat16(-1)))
np.testing.assert_equal(5.5, float(bfloat16(2) + bfloat16(3.5)))
np.testing.assert_equal(1.25, float(bfloat16(3.5) + bfloat16(-2.25)))
np.testing.assert_equal(
float("inf"), float(bfloat16(float("inf")) + bfloat16(-2.25)))
np.testing.assert_equal(
float("-inf"), float(bfloat16(float("-inf")) + bfloat16(-2.25)))
self.assertTrue(math.isnan(float(bfloat16(3.5) + bfloat16(float("nan")))))
# Test type promotion against Numpy scalar values.
self.assertEqual(np.float32, type(bfloat16(3.5) + np.float16(2.25)))
self.assertEqual(np.float32, type(np.float16(3.5) + bfloat16(2.25)))
self.assertEqual(np.float32, type(bfloat16(3.5) + np.float32(2.25)))
self.assertEqual(np.float32, type(np.float32(3.5) + bfloat16(2.25)))
self.assertEqual(np.float64, type(bfloat16(3.5) + np.float64(2.25)))
self.assertEqual(np.float64, type(np.float64(3.5) + bfloat16(2.25)))
self.assertEqual(np.float64, type(bfloat16(3.5) + float(2.25)))
self.assertEqual(np.float64, type(float(3.5) + bfloat16(2.25)))
self.assertEqual(np.float32,
type(bfloat16(3.5) + np.array(2.25, np.float32)))
self.assertEqual(np.float32,
type(np.array(3.5, np.float32) + bfloat16(2.25)))
def testSub(self):
np.testing.assert_equal(0, float(bfloat16(0) - bfloat16(0)))
np.testing.assert_equal(1, float(bfloat16(1) - bfloat16(0)))
np.testing.assert_equal(2, float(bfloat16(1) - bfloat16(-1)))
np.testing.assert_equal(-1.5, float(bfloat16(2) - bfloat16(3.5)))
np.testing.assert_equal(5.75, float(bfloat16(3.5) - bfloat16(-2.25)))
np.testing.assert_equal(
float("-inf"), float(bfloat16(-2.25) - bfloat16(float("inf"))))
np.testing.assert_equal(
float("inf"), float(bfloat16(-2.25) - bfloat16(float("-inf"))))
self.assertTrue(math.isnan(float(bfloat16(3.5) - bfloat16(float("nan")))))
def testMul(self):
np.testing.assert_equal(0, float(bfloat16(0) * bfloat16(0)))
np.testing.assert_equal(0, float(bfloat16(1) * bfloat16(0)))
np.testing.assert_equal(-1, float(bfloat16(1) * bfloat16(-1)))
np.testing.assert_equal(-7.875, float(bfloat16(3.5) * bfloat16(-2.25)))
np.testing.assert_equal(
float("-inf"), float(bfloat16(float("inf")) * bfloat16(-2.25)))
np.testing.assert_equal(
float("inf"), float(bfloat16(float("-inf")) * bfloat16(-2.25)))
self.assertTrue(math.isnan(float(bfloat16(3.5) * bfloat16(float("nan")))))
def testDiv(self):
self.assertTrue(math.isnan(float(bfloat16(0) / bfloat16(0))))
np.testing.assert_equal(float("inf"), float(bfloat16(1) / bfloat16(0)))
np.testing.assert_equal(-1, float(bfloat16(1) / bfloat16(-1)))
np.testing.assert_equal(-1.75, float(bfloat16(3.5) / bfloat16(-2)))
np.testing.assert_equal(
float("-inf"), float(bfloat16(float("inf")) / bfloat16(-2.25)))
np.testing.assert_equal(
float("inf"), float(bfloat16(float("-inf")) / bfloat16(-2.25)))
self.assertTrue(math.isnan(float(bfloat16(3.5) / bfloat16(float("nan")))))
def testLess(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v < w, bfloat16(v) < bfloat16(w))
def testLessEqual(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v <= w, bfloat16(v) <= bfloat16(w))
def testGreater(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v > w, bfloat16(v) > bfloat16(w))
def testGreaterEqual(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v >= w, bfloat16(v) >= bfloat16(w))
def testEqual(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v == w, bfloat16(v) == bfloat16(w))
def testNotEqual(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v != w, bfloat16(v) != bfloat16(w))
def testNan(self):
a = np.isnan(bfloat16(float("nan")))
self.assertTrue(a)
numpy_assert_allclose(np.array([1.0, a]), np.array([1.0, a]))
a = np.array([bfloat16(1.34375),
bfloat16(1.4375),
bfloat16(float("nan"))],
dtype=bfloat16)
b = np.array(
[bfloat16(1.3359375),
bfloat16(1.4375),
bfloat16(float("nan"))],
dtype=bfloat16)
numpy_assert_allclose(
a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True)
def testSort(self):
values_to_sort = np.float32(FLOAT_VALUES)
sorted_f32 = np.sort(values_to_sort)
sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
np.testing.assert_equal(sorted_f32, np.float32(sorted_bf16))
BinaryOp = collections.namedtuple("BinaryOp", ["op"])
UNARY_UFUNCS = [
np.negative, np.positive, np.absolute, np.fabs, np.rint, np.sign,
np.conjugate, np.exp, np.exp2, np.expm1, np.log, np.log10, np.log1p,
np.log2, np.sqrt, np.square, np.cbrt, np.reciprocal, np.sin, np.cos, np.tan,
np.arcsin, np.arccos, np.arctan, np.sinh, np.cosh, np.tanh, np.arcsinh,
np.arccosh, np.arctanh, np.deg2rad, np.rad2deg, np.floor, np.ceil, np.trunc
]
BINARY_UFUNCS = [
np.add, np.subtract, np.multiply, np.divide, np.logaddexp, np.logaddexp2,
np.floor_divide, np.power, np.remainder, np.fmod, np.heaviside, np.arctan2,
np.hypot, np.maximum, np.minimum, np.fmax, np.fmin, np.copysign
]
BINARY_PREDICATE_UFUNCS = [
np.equal, np.not_equal, np.less, np.greater, np.less_equal,
np.greater_equal, np.logical_and, np.logical_or, np.logical_xor
]
class Bfloat16NumPyTest(parameterized.TestCase):
"""Tests the NumPy integration of the bfloat16 type."""
def testDtype(self):
self.assertEqual(bfloat16, np.dtype(bfloat16))
def testDeepCopyDoesNotAlterHash(self):
# For context, see https://github.com/google/jax/issues/4651. If the hash
# value of the type descriptor is not initialized correctly, a deep copy
# can change the type hash.
dtype = np.dtype(bfloat16)
h = hash(dtype)
_ = copy.deepcopy(dtype)
self.assertEqual(h, hash(dtype))
def testArray(self):
x = np.array([[1, 2, 3]], dtype=bfloat16)
self.assertEqual(bfloat16, x.dtype)
self.assertEqual("[[1 2 3]]", str(x))
np.testing.assert_equal(x, x)
numpy_assert_allclose(x, x)
self.assertTrue((x == x).all())
def testComparisons(self):
x = np.array([401408, 7, -32], dtype=np.float32)
bx = x.astype(bfloat16)
y = np.array([82432, 7, 0], dtype=np.float32)
by = y.astype(bfloat16)
np.testing.assert_equal(x == y, bx == by)
np.testing.assert_equal(x != y, bx != by)
np.testing.assert_equal(x < y, bx < by)
np.testing.assert_equal(x > y, bx > by)
np.testing.assert_equal(x <= y, bx <= by)
np.testing.assert_equal(x >= y, bx >= by)
def testEqual2(self):
a = np.array([401408], bfloat16)
b = np.array([82432], bfloat16)
self.assertFalse(a.__eq__(b))
def testCasts(self):
for dtype in [
np.float16, np.float32, np.float64, np.int8, np.int16, np.int32,
np.int64, np.complex64, np.complex128, np.uint8, np.uint16, np.uint32,
np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong
]:
x = np.array([[1, 2, 3]], dtype=dtype)
y = x.astype(bfloat16)
z = y.astype(dtype)
self.assertTrue(np.all(x == y))
self.assertEqual(bfloat16, y.dtype)
self.assertTrue(np.all(x == z))
self.assertEqual(dtype, z.dtype)
def testConformNumpyComplex(self):
for dtype in [np.complex64, np.complex128]:
x = np.array([1.1, 2.2 + 2.2j, 3.3], dtype=dtype)
y_np = x.astype(np.float32)
y_tf = x.astype(bfloat16)
numpy_assert_allclose(y_np, y_tf, atol=2e-2)
z_np = y_np.astype(dtype)
z_tf = y_tf.astype(dtype)
numpy_assert_allclose(z_np, z_tf, atol=2e-2)
def testArange(self):
np.testing.assert_equal(
np.arange(100, dtype=np.float32).astype(bfloat16),
np.arange(100, dtype=bfloat16))
np.testing.assert_equal(
np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16),
np.arange(-10.5, 7.8, 0.5, dtype=bfloat16))
np.testing.assert_equal(
np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16),
np.arange(-0., -7., -0.25, dtype=bfloat16))
np.testing.assert_equal(
np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16),
np.arange(-16384., 16384., 64., dtype=bfloat16))
# pylint: disable=g-complex-comprehension
@parameterized.named_parameters(({
"testcase_name": "_" + op.__name__,
"op": op
} for op in UNARY_UFUNCS))
def testUnaryUfunc(self, op):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7, 10).astype(bfloat16)
numpy_assert_allclose(
op(x).astype(np.float32), op(x.astype(np.float32)), rtol=1e-2)
@parameterized.named_parameters(({
"testcase_name": "_" + op.__name__,
"op": op
} for op in BINARY_UFUNCS))
def testBinaryUfunc(self, op):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7, 10).astype(bfloat16)
y = rng.randn(4, 1, 7, 10).astype(bfloat16)
numpy_assert_allclose(
op(x, y).astype(np.float32),
op(x.astype(np.float32), y.astype(np.float32)),
rtol=1e-2)
@parameterized.named_parameters(({
"testcase_name": "_" + op.__name__,
"op": op
} for op in BINARY_PREDICATE_UFUNCS))
def testBinaryPredicateUfunc(self, op):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
y = rng.randn(4, 1, 7).astype(bfloat16)
np.testing.assert_equal(
op(x, y), op(x.astype(np.float32), y.astype(np.float32)))
@parameterized.named_parameters(({
"testcase_name": "_" + op.__name__,
"op": op
} for op in [np.isfinite, np.isinf, np.isnan, np.signbit, np.logical_not]))
def testPredicateUfunc(self, op):
rng = np.random.RandomState(seed=42)
shape = (3, 7, 10)
posinf_flips = rng.rand(*shape) < 0.1
neginf_flips = rng.rand(*shape) < 0.1
nan_flips = rng.rand(*shape) < 0.1
vals = rng.randn(*shape)
vals = np.where(posinf_flips, np.inf, vals)
vals = np.where(neginf_flips, -np.inf, vals)
vals = np.where(nan_flips, np.nan, vals)
vals = vals.astype(bfloat16)
np.testing.assert_equal(op(vals), op(vals.astype(np.float32)))
def testDivmod(self):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
y = rng.randn(4, 1, 7).astype(bfloat16)
o1, o2 = np.divmod(x, y)
e1, e2 = np.divmod(x.astype(np.float32), y.astype(np.float32))
numpy_assert_allclose(o1, e1, rtol=1e-2)
numpy_assert_allclose(o2, e2, rtol=1e-2)
def testModf(self):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
o1, o2 = np.modf(x)
e1, e2 = np.modf(x.astype(np.float32))
numpy_assert_allclose(o1.astype(np.float32), e1, rtol=1e-2)
numpy_assert_allclose(o2.astype(np.float32), e2, rtol=1e-2)
def testLdexp(self):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
y = rng.randint(-50, 50, (1, 7))
numpy_assert_allclose(
np.ldexp(x, y).astype(np.float32),
np.ldexp(x.astype(np.float32), y),
rtol=1e-2,
atol=1e-6)
def testFrexp(self):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
mant1, exp1 = np.frexp(x)
mant2, exp2 = np.frexp(x.astype(np.float32))
np.testing.assert_equal(exp1, exp2)
numpy_assert_allclose(mant1, mant2, rtol=1e-2)
def testNextAfter(self):
one = np.array(1., dtype=bfloat16)
two = np.array(2., dtype=bfloat16)
zero = np.array(0., dtype=bfloat16)
nan = np.array(np.nan, dtype=bfloat16)
np.testing.assert_equal(np.nextafter(one, two) - one, epsilon)
np.testing.assert_equal(np.nextafter(one, zero) - one, -epsilon / 2)
np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True)
np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True)
np.testing.assert_equal(np.nextafter(one, one), one)
smallest_denormal = float.fromhex("1.0p-133")
np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal)
np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal)
for a, b in itertools.permutations([0., -0., nan], 2):
np.testing.assert_equal(
np.nextafter(
np.array(a, dtype=np.float32), np.array(b, dtype=np.float32)),
np.nextafter(
np.array(a, dtype=bfloat16), np.array(b, dtype=bfloat16)))
if __name__ == "__main__":
absltest.main()


@ -214,11 +214,9 @@ StatusOr<std::vector<int64>> StridesToLayout(absl::Span<int64 const> dims,
} }
StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) { StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
const se::Platform* platform = if (device.client()->platform_id() == kCpuId) {
device.local_device_state()->executor()->platform();
if (platform->id() == se::host::kHostPlatformId) {
return kDLCPU; return kDLCPU;
} else if (platform->id() == se::cuda::kCudaPlatformId) { } else if (device.client()->platform_id() == kGpuId) {
return kDLGPU; return kDLGPU;
} }
return InvalidArgument("Device %s cannot be used as a DLPack device.", return InvalidArgument("Device %s cannot be used as a DLPack device.",
@ -228,7 +226,7 @@ StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
StatusOr<DLContext> DLContextForDevice(const PjRtDevice& device) { StatusOr<DLContext> DLContextForDevice(const PjRtDevice& device) {
DLContext context; DLContext context;
TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device)); TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
context.device_id = device.local_device_id(); context.device_id = device.local_hardware_id();
return context; return context;
} }
@ -241,14 +239,14 @@ StatusOr<PjRtDevice*> DeviceForDLContext(const PjRtClient& client,
"DLPack CPU device type mismatch with PjRtClient platform %s", "DLPack CPU device type mismatch with PjRtClient platform %s",
client.platform_name()); client.platform_name());
} }
return client.LookupLocalDevice(context.device_id); return client.LookupAddressableDevice(context.device_id);
case kDLGPU: case kDLGPU:
if (client.platform_id() != kGpuId) { if (client.platform_id() != kGpuId) {
return InvalidArgument( return InvalidArgument(
"DLPack GPU device type mismatch with PjRtClient platform %s", "DLPack GPU device type mismatch with PjRtClient platform %s",
client.platform_name()); client.platform_name());
} }
return client.LookupLocalDevice(context.device_id); return client.LookupAddressableDevice(context.device_id);
default: default:
return InvalidArgument("Unknown/unsupported DLPack device type %d", return InvalidArgument("Unknown/unsupported DLPack device type %d",
context.device_type); context.device_type);
@ -297,7 +295,7 @@ StatusOr<py::capsule> BufferToDLPackManagedTensor(py::handle py_buffer,
pack->tensor.manager_ctx = pack.get(); pack->tensor.manager_ctx = pack.get();
pack->tensor.deleter = DLPackTensorDeleter; pack->tensor.deleter = DLPackTensorDeleter;
TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device())); TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device()));
dt.ctx.device_id = buffer->buffer()->device()->local_device_id(); dt.ctx.device_id = buffer->buffer()->device()->local_hardware_id();
dt.ndim = buffer->buffer()->on_host_shape().dimensions_size(); dt.ndim = buffer->buffer()->on_host_shape().dimensions_size();
TF_ASSIGN_OR_RETURN(dt.dtype, TF_ASSIGN_OR_RETURN(dt.dtype,
PrimitiveTypeToDLDataType( PrimitiveTypeToDLDataType(

View File

@ -230,8 +230,8 @@ OutfeedReceiverImpl::OutfeedReceiverImpl(
callback_ = callback; callback_ = callback;
max_callback_queue_size_bytes_ = max_callback_queue_size_bytes; max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
for (const auto& client : clients) { for (const auto& client : clients) {
for (const auto& device : client->devices()) { for (auto device : client->devices()) {
devices_.push_back(device.get()); devices_.push_back(device);
} }
} }
CHECK_GT(devices_.size(), 0); CHECK_GT(devices_.size(), 0);
@ -342,11 +342,7 @@ StatusOr<std::unique_ptr<Literal>> OutfeedReceiverImpl::ReceiveRawFromOutfeed(
const PjRtDevice* device, const Shape& shape) { const PjRtDevice* device, const Shape& shape) {
std::shared_ptr<Literal> literal_shared; std::shared_ptr<Literal> literal_shared;
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, TF_ASSIGN_OR_RETURN(Literal literal, device->TransferFromOutfeed(shape));
device->GetLocalDeviceState());
TF_ASSIGN_OR_RETURN(Literal literal,
local_device->client()->TransferFromOutfeedLocal(
shape, local_device->device_ordinal()));
return absl::make_unique<Literal>(std::move(literal)); return absl::make_unique<Literal>(std::move(literal));
} }

View File

@ -86,8 +86,8 @@ StatusOr<std::uintptr_t> PyBuffer::UnsafeBufferPointer() const {
} }
StatusOr<py::dict> PyBuffer::CudaArrayInterface() const { StatusOr<py::dict> PyBuffer::CudaArrayInterface() const {
if (buffer_->device()->local_device_state()->executor()->platform_kind() != // TODO(zhangqiaorjc): Differentiate between NVidia and other GPUs.
se::PlatformKind::kCuda) { if (buffer_->client()->platform_id() != kGpuId) {
return InvalidArgument( return InvalidArgument(
"__cuda_array_interface__ is only defined for NVidia GPU buffers."); "__cuda_array_interface__ is only defined for NVidia GPU buffers.");
} }

View File

@ -37,9 +37,10 @@ PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
std::vector<ClientAndPtr<PjRtDevice>> PyClient::Devices() { std::vector<ClientAndPtr<PjRtDevice>> PyClient::Devices() {
std::vector<ClientAndPtr<PjRtDevice>> devices; std::vector<ClientAndPtr<PjRtDevice>> devices;
devices.reserve(pjrt_client_->devices().size()); auto span = pjrt_client_->devices();
for (const auto& device : pjrt_client_->devices()) { devices.reserve(span.size());
devices.push_back(WrapWithClient(shared_from_this(), device.get())); for (PjRtDevice* device : span) {
devices.push_back(WrapWithClient(shared_from_this(), device));
} }
return devices; return devices;
} }
@ -64,9 +65,9 @@ PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
result[r].resize(num_partitions); result[r].resize(num_partitions);
for (int p = 0; p < num_partitions; ++p) { for (int p = 0; p < num_partitions; ++p) {
int device_id = device_assignment(r, p); int device_id = device_assignment(r, p);
auto iter = pjrt_client_->id_to_device().find(device_id); TF_ASSIGN_OR_RETURN(PjRtDevice * device,
CHECK(iter != pjrt_client_->id_to_device().end()) << device_id; pjrt_client_->LookupDevice(device_id));
result[r][p] = WrapWithClient(shared_from_this(), iter->second); result[r][p] = WrapWithClient(shared_from_this(), device);
} }
} }
return result; return result;
@ -80,9 +81,9 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
std::vector<ClientAndPtr<PjRtDevice>> result; std::vector<ClientAndPtr<PjRtDevice>> result;
for (int i = 0; i < num_replicas; ++i) { for (int i = 0; i < num_replicas; ++i) {
int device_id = device_assignment(i, 0); int device_id = device_assignment(i, 0);
auto iter = pjrt_client_->id_to_device().find(device_id); TF_ASSIGN_OR_RETURN(PjRtDevice * device,
CHECK(iter != pjrt_client_->id_to_device().end()) << device_id; pjrt_client_->LookupDevice(device_id));
result.push_back(WrapWithClient(shared_from_this(), iter->second)); result.push_back(WrapWithClient(shared_from_this(), device));
} }
return result; return result;
} }
@ -95,8 +96,9 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
device = pjrt_client_->local_devices().front(); device = pjrt_client_->local_devices().front();
} }
CHECK(device != nullptr); CHECK(device != nullptr);
auto iter = pjrt_client_->id_to_device().find(device->id()); TF_ASSIGN_OR_RETURN(PjRtDevice * found_device,
if (iter->second != device) { pjrt_client_->LookupDevice(device->id()));
if (found_device != device) {
return InvalidArgument("Cannot copy value to device '%s' with '%s' backend", return InvalidArgument("Cannot copy value to device '%s' with '%s' backend",
device->DebugString(), device->DebugString(),
pjrt_client_->platform_name()); pjrt_client_->platform_name());

View File

@ -97,7 +97,9 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
const std::string& platform_name() const { const std::string& platform_name() const {
return pjrt_client_->platform_name(); return pjrt_client_->platform_name();
} }
int local_device_count() const { return pjrt_client_->local_device_count(); } int addressable_device_count() const {
return pjrt_client_->addressable_device_count();
}
int device_count() const { return pjrt_client_->device_count(); } int device_count() const { return pjrt_client_->device_count(); }
int host_id() const { return pjrt_client_->host_id(); } int host_id() const { return pjrt_client_->host_id(); }

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#include "pybind11/pytypes.h" #include "pybind11/pytypes.h"
#include "pybind11/stl.h" #include "pybind11/stl.h"
#include "tensorflow/compiler/xla/python/types.h"
namespace xla { namespace xla {
@ -106,59 +107,66 @@ bool PyTreeDef::operator==(const PyTreeDef& other) const {
} }
} }
void PyTreeDef::FlattenInto(py::handle handle, void PyTreeDef::FlattenInto(py::handle handle, std::vector<py::object>& leaves,
std::vector<py::object>& leaves) { absl::optional<py::function> leaf_predicate) {
Node node; Node node;
int start_num_nodes = traversal_.size(); int start_num_nodes = traversal_.size();
int start_num_leaves = leaves.size(); int start_num_leaves = leaves.size();
node.kind = GetKind(handle, &node.custom); if (leaf_predicate && (*leaf_predicate)(handle).cast<bool>()) {
if (node.kind == Kind::kNone) { leaves.push_back(py::reinterpret_borrow<py::object>(handle));
// Nothing to do.
} else if (node.kind == Kind::kTuple) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = tuple.size();
for (py::handle entry : tuple) {
FlattenInto(entry, leaves);
}
} else if (node.kind == Kind::kList) {
py::list list = py::reinterpret_borrow<py::list>(handle);
node.arity = list.size();
for (py::handle entry : list) {
FlattenInto(entry, leaves);
}
} else if (node.kind == Kind::kDict) {
py::dict dict = py::reinterpret_borrow<py::dict>(handle);
py::list keys = py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
if (PyList_Sort(keys.ptr())) {
throw std::runtime_error("Dictionary key sort failed.");
}
for (py::handle key : keys) {
FlattenInto(dict[key], leaves);
}
node.arity = dict.size();
node.node_data = std::move(keys);
} else if (node.kind == Kind::kCustom) {
py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
if (out.size() != 2) {
throw std::runtime_error(
"PyTree custom to_iterable function should return a pair");
}
node.node_data = out[1];
node.arity = 0;
for (py::handle entry : py::cast<py::iterable>(out[0])) {
++node.arity;
FlattenInto(entry, leaves);
}
} else if (node.kind == Kind::kNamedTuple) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = tuple.size();
node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
for (py::handle entry : tuple) {
FlattenInto(entry, leaves);
}
} else { } else {
assert(node.kind == Kind::kLeaf); node.kind = GetKind(handle, &node.custom);
leaves.push_back(pybind11::reinterpret_borrow<py::object>(handle)); auto recurse = [this, &leaf_predicate, &leaves](py::handle child) {
FlattenInto(child, leaves, leaf_predicate);
};
if (node.kind == Kind::kNone) {
// Nothing to do.
} else if (node.kind == Kind::kTuple) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = tuple.size();
for (py::handle entry : tuple) {
recurse(entry);
}
} else if (node.kind == Kind::kList) {
py::list list = py::reinterpret_borrow<py::list>(handle);
node.arity = list.size();
for (py::handle entry : list) {
recurse(entry);
}
} else if (node.kind == Kind::kDict) {
py::dict dict = py::reinterpret_borrow<py::dict>(handle);
py::list keys = py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
if (PyList_Sort(keys.ptr())) {
throw std::runtime_error("Dictionary key sort failed.");
}
for (py::handle key : keys) {
recurse(dict[key]);
}
node.arity = dict.size();
node.node_data = std::move(keys);
} else if (node.kind == Kind::kCustom) {
py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
if (out.size() != 2) {
throw std::runtime_error(
"PyTree custom to_iterable function should return a pair");
}
node.node_data = out[1];
node.arity = 0;
for (py::handle entry : py::cast<py::iterable>(out[0])) {
++node.arity;
recurse(entry);
}
} else if (node.kind == Kind::kNamedTuple) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = tuple.size();
node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
for (py::handle entry : tuple) {
recurse(entry);
}
} else {
assert(node.kind == Kind::kLeaf);
leaves.push_back(py::reinterpret_borrow<py::object>(handle));
}
} }
node.num_nodes = traversal_.size() - start_num_nodes + 1; node.num_nodes = traversal_.size() - start_num_nodes + 1;
node.num_leaves = leaves.size() - start_num_leaves; node.num_leaves = leaves.size() - start_num_leaves;
@ -166,10 +174,10 @@ void PyTreeDef::FlattenInto(py::handle handle,
} }
/*static*/ std::pair<std::vector<py::object>, std::unique_ptr<PyTreeDef>> /*static*/ std::pair<std::vector<py::object>, std::unique_ptr<PyTreeDef>>
PyTreeDef::Flatten(py::handle x) { PyTreeDef::Flatten(py::handle x, absl::optional<py::function> leaf_predicate) {
std::vector<py::object> leaves; std::vector<py::object> leaves;
auto tree = absl::make_unique<PyTreeDef>(); auto tree = absl::make_unique<PyTreeDef>();
tree->FlattenInto(x, leaves); tree->FlattenInto(x, leaves, leaf_predicate);
return std::make_pair(std::move(leaves), std::move(tree)); return std::make_pair(std::move(leaves), std::move(tree));
} }
@ -618,7 +626,8 @@ std::string PyTreeDef::ToString() const {
void BuildPytreeSubmodule(py::module& m) { void BuildPytreeSubmodule(py::module& m) {
py::module pytree = m.def_submodule("pytree", "Python tree library"); py::module pytree = m.def_submodule("pytree", "Python tree library");
pytree.def("flatten", &PyTreeDef::Flatten); pytree.def("flatten", &PyTreeDef::Flatten, py::arg("tree"),
py::arg("leaf_predicate") = absl::nullopt);
pytree.def("tuple", &PyTreeDef::Tuple); pytree.def("tuple", &PyTreeDef::Tuple);
pytree.def("all_leaves", &PyTreeDef::AllLeaves); pytree.def("all_leaves", &PyTreeDef::AllLeaves);

View File

@ -85,11 +85,13 @@ class PyTreeDef {
// Flattens a Pytree into a list of leaves and a PyTreeDef. // Flattens a Pytree into a list of leaves and a PyTreeDef.
static std::pair<std::vector<pybind11::object>, std::unique_ptr<PyTreeDef>> static std::pair<std::vector<pybind11::object>, std::unique_ptr<PyTreeDef>>
Flatten(pybind11::handle x); Flatten(pybind11::handle x,
absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
// Recursive helper used to implement Flatten(). // Recursive helper used to implement Flatten().
void FlattenInto(pybind11::handle handle, void FlattenInto(
std::vector<pybind11::object>& leaves); pybind11::handle handle, std::vector<pybind11::object>& leaves,
absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
// Tests whether the given list is a flat list of leaves. // Tests whether the given list is a flat list of leaves.
static bool AllLeaves(const pybind11::iterable& x); static bool AllLeaves(const pybind11::iterable& x);
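The `Flatten`/`FlattenInto` change above threads an optional `leaf_predicate` through the recursion: when the predicate accepts a subtree, recursion stops and the whole subtree is recorded as a single leaf. A minimal pure-Python sketch of that idea (illustrative only, not the C++ `PyTreeDef` binding), assuming tuples, lists, and dicts as the only containers and visiting dict keys in sorted order as the C++ code does:

def flatten(tree, leaf_predicate=None):
  """Pure-Python sketch of flattening with an optional leaf predicate."""
  leaves = []

  def recurse(node):
    if leaf_predicate is not None and leaf_predicate(node):
      leaves.append(node)          # predicate short-circuits recursion
    elif isinstance(node, (tuple, list)):
      for child in node:
        recurse(child)
    elif isinstance(node, dict):
      for key in sorted(node):     # keys are visited in sorted order
        recurse(node[key])
    else:
      leaves.append(node)

  recurse(tree)
  return leaves

# Without a predicate every scalar is a leaf; with one, whole lists are kept.
assert flatten({"a": [1, 2], "b": (3,)}) == [1, 2, 3]
assert flatten({"a": [1, 2], "b": (3,)},
               leaf_predicate=lambda x: isinstance(x, list)) == [[1, 2], 3]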

View File

@ -37,6 +37,7 @@ cc_library(
"//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core/framework:allocator", "//tensorflow/core/framework:allocator",
"//tensorflow/core/platform:casts",
"//tensorflow/core/platform:env", "//tensorflow/core/platform:env",
"//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",

View File

@ -37,8 +37,8 @@ namespace xla {
TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords, TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
int core_on_chip) int core_on_chip)
: xla::PjRtDevice(id, /*local_device_state=*/nullptr, : xla::PjRtStreamExecutorDevice(id, /*local_device_state=*/nullptr,
/*device_kind=*/"Cloud TPU", host_id), /*device_kind=*/"Cloud TPU", host_id),
coords_(coords), coords_(coords),
core_on_chip_(core_on_chip) {} core_on_chip_(core_on_chip) {}
@ -531,7 +531,7 @@ PyTpuExecutable::PyTpuExecutable(
<< "Inserting duplicate replica:" << replica; << "Inserting duplicate replica:" << replica;
executables_[replica] = executables_[replica] =
client_->driver()->LoadProgram(device_id, compiled_program.get(), {}); client_->driver()->LoadProgram(device_id, compiled_program.get(), {});
addressable_device_logical_ids_.emplace_back(replica, partition); local_logical_device_ids_.emplace_back(replica, partition);
local_devices_.push_back(device); local_devices_.push_back(device);
} }
} }
@ -711,8 +711,8 @@ PyTpuExecutable::ExecuteOnLocalDevices(
// long time and we want all cores to be scheduled in parallel. // long time and we want all cores to be scheduled in parallel.
thread_pool->Schedule([this, i, argument_handles, &results, &results_lock, thread_pool->Schedule([this, i, argument_handles, &results, &results_lock,
&execute_semaphore]() { &execute_semaphore]() {
const int replica = addressable_device_logical_ids_[i].first; const int replica = local_logical_device_ids_[i].first;
const int partition = addressable_device_logical_ids_[i].second; const int partition = local_logical_device_ids_[i].second;
RunId run_id; RunId run_id;
auto result = ExecuteHelper(argument_handles, argument_handles[i], auto result = ExecuteHelper(argument_handles, argument_handles[i],
replica, partition, run_id); replica, partition, run_id);

View File

@ -32,13 +32,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/platform/threadpool.h"
namespace xla { namespace xla {
constexpr char kTpuPlatform[] = "tpu"; constexpr char kTpuPlatform[] = "tpu";
class TpuDevice : public PjRtDevice { class TpuDevice : public PjRtStreamExecutorDevice {
public: public:
TpuDevice(int id, int host_id, const std::array<int, 3>& coords, TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
int core_on_chip); int core_on_chip);
@ -298,9 +299,8 @@ class PyTpuExecutable {
return device_assignment_; return device_assignment_;
} }
const std::vector<std::pair<int, int>>& addressable_device_logical_ids() const std::vector<std::pair<int, int>>& local_logical_device_ids() const {
const { return local_logical_device_ids_;
return addressable_device_logical_ids_;
} }
const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() const { const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() const {
@ -341,14 +341,16 @@ class PyTpuExecutable {
// The replica and partition indices of device_assignment_ to be run by this // The replica and partition indices of device_assignment_ to be run by this
// client. On single-host platforms without partitioning, this is all replicas // client. On single-host platforms without partitioning, this is all replicas
// (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the // (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case
// case on multi-host platforms. If there are 4 replicas and 2 partitions on a // on multi-host platforms.
// single host platform, size of addressable_device_logical_ids_ is 4*2 = 8. // If there are 4 replicas and 2 partitions on a single host platform, size of
std::vector<std::pair<int, int>> addressable_device_logical_ids_; // local_logical_device_ids_ is 4*2 = 8.
std::vector<std::pair<int, int>> local_logical_device_ids_;
// local_devices_[i] is the Device to which addressable_device_logical_ids_[i] // local_devices_[i] is the Device to which local_logical_device_ids_[i] is
// is assigned. shared_ptrs instead of unique_ptrs to play well with the // assigned.
// Python bindings (see xla.cc). // shared_ptrs instead of unique_ptrs to play well with the Python bindings
// (see xla.cc).
std::vector<std::shared_ptr<PjRtDevice>> local_devices_; std::vector<std::shared_ptr<PjRtDevice>> local_devices_;
xla::Shape result_shape_; xla::Shape result_shape_;
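The comment above reasons about the size of `local_logical_device_ids_`: one (replica, partition) entry per addressable logical device, so 4 replicas and 2 partitions on a single host give 4*2 = 8 entries. A small Python sketch of that enumeration, purely to illustrate the counting:

import itertools

num_replicas, num_partitions = 4, 2

# On a single-host platform every (replica, partition) pair is local, so the
# list has num_replicas * num_partitions = 8 entries.
local_logical_device_ids = list(
    itertools.product(range(num_replicas), range(num_partitions)))

assert len(local_logical_device_ids) == 8
assert local_logical_device_ids[0] == (0, 0)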

View File

@ -186,7 +186,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
py::class_<PyTpuExecutable>(m, "TpuExecutable") py::class_<PyTpuExecutable>(m, "TpuExecutable")
.def("local_logical_device_ids", .def("local_logical_device_ids",
&PyTpuExecutable::addressable_device_logical_ids) &PyTpuExecutable::local_logical_device_ids)
.def("local_devices", &PyTpuExecutable::local_devices) .def("local_devices", &PyTpuExecutable::local_devices)
.def_property_readonly("client", &PyTpuExecutable::client) .def_property_readonly("client", &PyTpuExecutable::client)
.def("size_of_generated_code_in_bytes", .def("size_of_generated_code_in_bytes",

View File

@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/types.h" #include "tensorflow/compiler/xla/python/types.h"
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/python/bfloat16.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/python/lib/core/bfloat16.h"
namespace xla { namespace xla {
@ -81,8 +81,8 @@ xla::StatusOr<py::dtype> PrimitiveTypeToDtype(PrimitiveType type) {
case U64: case U64:
return py::dtype::of<uint64>(); return py::dtype::of<uint64>();
case BF16: { case BF16: {
TF_ASSIGN_OR_RETURN(py::object bfloat16, Bfloat16Dtype()); py::handle bfloat16(tensorflow::Bfloat16Dtype());
return py::dtype::from_args(bfloat16); return py::dtype::from_args(py::reinterpret_borrow<py::object>(bfloat16));
} }
case F16: case F16:
return py::dtype("e"); // PEP 3118 code for "float16 return py::dtype("e"); // PEP 3118 code for "float16
@ -237,10 +237,11 @@ StatusOr<py::object> LiteralToPython(std::shared_ptr<xla::Literal> literal) {
// We requested an array of uint16 since NumPy doesn't know how // We requested an array of uint16 since NumPy doesn't know how
// to produce our custom bfloat16 type. Reinterpret the array as bfloat16 // to produce our custom bfloat16 type. Reinterpret the array as bfloat16
// before handing it back to the caller. // before handing it back to the caller.
TF_ASSIGN_OR_RETURN(py::object bfloat16, Bfloat16Dtype()); py::handle bfloat16(tensorflow::Bfloat16Dtype());
bfloat16.inc_ref();
array = py::reinterpret_steal<py::array>( array = py::reinterpret_steal<py::array>(
PyArray_View(reinterpret_cast<PyArrayObject*>(array.ptr()), PyArray_View(reinterpret_cast<PyArrayObject*>(array.ptr()),
reinterpret_cast<PyArray_Descr*>(bfloat16.release().ptr()), reinterpret_cast<PyArray_Descr*>(bfloat16.ptr()),
static_cast<PyTypeObject*>(nullptr))); static_cast<PyTypeObject*>(nullptr)));
} }
return array; return array;
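The hunk above keeps the trick of asking NumPy for uint16 storage and then reinterpreting the buffer as bfloat16 before returning it. A minimal sketch of the same reinterpret-the-bits idea using stock NumPy float16 (bfloat16 is a custom dtype registered by the extension, so it is not used here); the byte pattern is unchanged, only the dtype label:

import numpy as np

# 0x3C00 and 0x4000 are the bit patterns of float16 1.0 and 2.0; start from
# raw uint16 storage, as the C++ code does for bfloat16.
raw = np.array([0x3C00, 0x4000], dtype=np.uint16)

# .view() reinterprets the same buffer under a different dtype; no copy made.
as_f16 = raw.view(np.float16)

assert as_f16.dtype == np.float16
np.testing.assert_array_equal(as_f16, np.array([1.0, 2.0], dtype=np.float16))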

View File

@ -40,7 +40,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/pjrt/interpreter_device.h" #include "tensorflow/compiler/xla/pjrt/interpreter_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tpu_client.h" #include "tensorflow/compiler/xla/pjrt/tpu_client.h"
#include "tensorflow/compiler/xla/python/bfloat16.h"
#include "tensorflow/compiler/xla/python/dlpack.h" #include "tensorflow/compiler/xla/python/dlpack.h"
#include "tensorflow/compiler/xla/python/jax_jit.h" #include "tensorflow/compiler/xla/python/jax_jit.h"
#include "tensorflow/compiler/xla/python/ops.h" #include "tensorflow/compiler/xla/python/ops.h"
@ -59,6 +58,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
#include "tensorflow/python/lib/core/bfloat16.h"
#include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/platform.h"
namespace xla { namespace xla {
@ -110,6 +110,8 @@ PYBIND11_MODULE(xla_extension, m) {
throw std::runtime_error("Unable to initialize Numpy API"); throw std::runtime_error("Unable to initialize Numpy API");
} }
CHECK(tensorflow::RegisterNumpyBfloat16());
// Types // Types
py::enum_<PrimitiveType>(m, "PrimitiveType") py::enum_<PrimitiveType>(m, "PrimitiveType")
.value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID) .value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID)
@ -132,7 +134,8 @@ PYBIND11_MODULE(xla_extension, m) {
.value("OPAQUE_TYPE", OPAQUE_TYPE) .value("OPAQUE_TYPE", OPAQUE_TYPE)
.value("TOKEN", TOKEN); .value("TOKEN", TOKEN);
m.def("bfloat16_dtype", Bfloat16Dtype); m.def("bfloat16_dtype",
[]() { return py::handle(tensorflow::Bfloat16Dtype()); });
// Must be before PyClient.compile. // Must be before PyClient.compile.
BuildXlaCompilerSubmodule(m); BuildXlaCompilerSubmodule(m);
@ -149,7 +152,10 @@ PYBIND11_MODULE(xla_extension, m) {
.def_property_readonly("host_id", &PjRtDevice::host_id, .def_property_readonly("host_id", &PjRtDevice::host_id,
"Integer ID of this device's host.\n\n" "Integer ID of this device's host.\n\n"
"This is always 0 except on multi-host platforms.") "This is always 0 except on multi-host platforms.")
.def_property_readonly("platform", &PjRtDevice::platform_name) .def_property_readonly("platform",
[](const PjRtDevice& device) {
return device.client()->platform_name();
})
.def_property_readonly("device_kind", &PjRtDevice::device_kind) .def_property_readonly("device_kind", &PjRtDevice::device_kind)
.def_property_readonly( .def_property_readonly(
"client", "client",
@ -234,7 +240,7 @@ PYBIND11_MODULE(xla_extension, m) {
py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client"); py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
py_local_client.def_property_readonly("platform", &PyClient::platform_name) py_local_client.def_property_readonly("platform", &PyClient::platform_name)
.def("device_count", &PyClient::device_count) .def("device_count", &PyClient::device_count)
.def("local_device_count", &PyClient::local_device_count) .def("local_device_count", &PyClient::addressable_device_count)
.def("devices", &PyClient::Devices) .def("devices", &PyClient::Devices)
.def("local_devices", &PyClient::LocalDevices) .def("local_devices", &PyClient::LocalDevices)
.def("host_id", &PyClient::host_id) .def("host_id", &PyClient::host_id)
@ -381,10 +387,10 @@ PYBIND11_MODULE(xla_extension, m) {
[](PyExecutable* exec) { [](PyExecutable* exec) {
auto span = exec->addressable_device_logical_ids(); auto span = exec->addressable_device_logical_ids();
// Not on dispatch critical path, so ok to have heap allocation. // Not on dispatch critical path, so ok to have heap allocation.
std::vector<std::pair<int, int>> addressable_device_logical_ids; std::vector<std::pair<int, int>> addressable_device_logic_ids;
addressable_device_logical_ids.reserve(span.size()); addressable_device_logic_ids.reserve(span.size());
for (const auto& logical_device_id : span) { for (const auto& logical_device_id : span) {
addressable_device_logical_ids.push_back(std::make_pair( addressable_device_logic_ids.push_back(std::make_pair(
logical_device_id.replica, logical_device_id.partition)); logical_device_id.replica, logical_device_id.partition));
} }
}) })

View File

@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/threadpool.h"
namespace xla { namespace xla {
@ -158,6 +159,18 @@ class AotCompilationMetadata {
// platform. // platform.
class Compiler { class Compiler {
public: public:
struct CompileOptions {
// If device_allocator is not null, the compiler may use it to allocate temp
// space on the device for use during compilation. For example, the
// compiler may allocate buffers on the device and then run variants of a
// given algorithm over those buffers, to see which variant is fastest. Any
// space allocated will be deallocated before the compilation returns.
se::DeviceMemoryAllocator* device_allocator = nullptr;
// An optional thread pool for parallel compilation.
tensorflow::thread::ThreadPool* thread_pool = nullptr;
};
virtual ~Compiler() {} virtual ~Compiler() {}
// Returns the ID of the platform that this compiler targets. // Returns the ID of the platform that this compiler targets.
@ -165,31 +178,24 @@ class Compiler {
// Runs Hlo passes to optimize the given Hlo module, returns the optimized // Runs Hlo passes to optimize the given Hlo module, returns the optimized
// module. // module.
//
// If device_allocator is not null, the compiler may use it to allocate temp
// space on the device for use during compilation. For example, the compiler
// may allocate buffers on the device and then run variants of a given
// algorithm over those buffers, to see which variant is fastest. Any space
// allocated should be deallocated before this function returns.
virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses( virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* executor, std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator) = 0; const CompileOptions& options) = 0;
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator) {
return RunHloPasses(std::move(module), executor,
CompileOptions{device_allocator});
}
// Runs HLO passes to optimize the given HloModule, perform scheduling and // Runs HLO passes to optimize the given HloModule, perform scheduling and
// buffer assignment, returns the optimized module and the buffer assignments. // buffer assignment, returns the optimized module and the buffer assignments.
// This interface is intentionally narrow. // This interface is intentionally narrow.
//
// If device_allocator is not null, the compiler may use it to allocate temp
// space on the device for use during compilation. For example, the compiler
// may allocate buffers on the device and then run variants of a given
// algorithm over those buffers, to see which variant is fastest. Any space
// allocated should be deallocated before this function returns.
virtual StatusOr< virtual StatusOr<
std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>> std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module, RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
se::StreamExecutor* executor, se::StreamExecutor* executor, bool optimize,
se::DeviceMemoryAllocator* device_allocator, const CompileOptions& options) {
bool optimize) {
return Unimplemented("This compiler does not support this method"); return Unimplemented("This compiler does not support this method");
} }
@ -201,24 +207,33 @@ class Compiler {
// //
// The compiler may optionally specialize to the individual device // The compiler may optionally specialize to the individual device
// (not just type of device) indicated by the executor. // (not just type of device) indicated by the executor.
//
// device_allocator is optional; see RunHloPasses.
virtual StatusOr<std::unique_ptr<Executable>> RunBackend( virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* executor, std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator) = 0; const CompileOptions& options) = 0;
StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator) {
return RunBackend(std::move(module), executor,
CompileOptions{device_allocator});
}
// Compiles a set of HLO modules that can run in parallel, potentially // Compiles a set of HLO modules that can run in parallel, potentially
// communicating data between the modules, and returns a corresponding // communicating data between the modules, and returns a corresponding
// sequence of executable objects. // sequence of executable objects.
// //
// device_allocator is optional; see RunHloPasses.
//
// TODO(b/68666782): Remove this method after adding support for multiple // TODO(b/68666782): Remove this method after adding support for multiple
// modules to RunHloPasses and RunBackends. // modules to RunHloPasses and RunBackends.
virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile( virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::unique_ptr<HloModuleGroup> module_group, std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_exec, std::vector<std::vector<se::StreamExecutor*>> stream_exec,
se::DeviceMemoryAllocator* device_allocator) = 0; const CompileOptions& options) = 0;
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
se::DeviceMemoryAllocator* device_allocator) {
return Compile(std::move(module_group), stream_exec,
CompileOptions{device_allocator});
}
// Returns the backend configurations that the backend will consider for the // Returns the backend configurations that the backend will consider for the
// given HLO. Returns no configurations if the backend does not support // given HLO. Returns no configurations if the backend does not support

View File

@ -553,7 +553,7 @@ Status CreateHloProfilingArtifacts(
StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses( StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* /*stream_exec*/, std::unique_ptr<HloModule> module, se::StreamExecutor* /*stream_exec*/,
se::DeviceMemoryAllocator* /*device_allocator*/) { const CompileOptions& /*options*/) {
std::unique_ptr<llvm::TargetMachine> jit_target_machine = std::unique_ptr<llvm::TargetMachine> jit_target_machine =
SimpleOrcJIT::InferTargetMachineForJIT( SimpleOrcJIT::InferTargetMachineForJIT(
CompilerTargetOptions(module->config()), CompilerTargetOptions(module->config()),
@ -566,12 +566,13 @@ StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
StatusOr< StatusOr<
std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>> std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
CpuCompiler::RunHloPassesAndBufferAssignement( CpuCompiler::RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
std::unique_ptr<HloModule> module, se::StreamExecutor* executor, se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator, bool optimize) { bool optimize,
const CompileOptions& options) {
if (optimize) { if (optimize) {
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(module,
module, RunHloPasses(std::move(module), executor, device_allocator)); RunHloPasses(std::move(module), executor, options));
} }
// Select an order for emitting the HLO instructions for each computation. // Select an order for emitting the HLO instructions for each computation.
@ -632,7 +633,7 @@ struct OrcJITPostCompilationHook {
StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend( StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec, std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* /*device_allocator*/) { const CompileOptions& options) {
VLOG(1) << "Compiling: " << module->name(); VLOG(1) << "Compiling: " << module->name();
XLA_SCOPED_LOGGING_TIMER( XLA_SCOPED_LOGGING_TIMER(
absl::StrFormat("Compiling [%s] for CPU using JIT", module->name())); absl::StrFormat("Compiling [%s] for CPU using JIT", module->name()));

View File

@ -134,18 +134,17 @@ class CpuCompiler : public LLVMCompiler {
StatusOr<std::unique_ptr<HloModule>> RunHloPasses( StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec, std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override; const CompileOptions& options) override;
StatusOr< StatusOr<
std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>> std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module, RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
se::StreamExecutor* executor, se::StreamExecutor* executor, bool optimize,
se::DeviceMemoryAllocator* device_allocator, const CompileOptions& options) override;
bool optimize) override;
StatusOr<std::unique_ptr<Executable>> RunBackend( StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec, std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override; const CompileOptions& options) override;
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group, CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,

View File

@ -440,7 +440,7 @@ filegroup(
name = "nccl_collective_thunk_src", name = "nccl_collective_thunk_src",
srcs = if_nccl( srcs = if_nccl(
["nccl_collective_thunk.cc"], ["nccl_collective_thunk.cc"],
["dummy_collective_thunk.cc"], ["nccl_collective_thunk_dummy.cc"],
), ),
) )
@ -448,7 +448,7 @@ tf_cuda_library(
name = "nccl_collective_thunk", name = "nccl_collective_thunk",
srcs = if_cuda_or_rocm( srcs = if_cuda_or_rocm(
[":nccl_collective_thunk_src"], [":nccl_collective_thunk_src"],
["dummy_collective_thunk.cc"], ["nccl_collective_thunk_dummy.cc"],
), ),
hdrs = ["nccl_collective_thunk.h"], hdrs = ["nccl_collective_thunk.h"],
deps = [ deps = [
@ -480,7 +480,7 @@ filegroup(
name = "nccl_all_gather_thunk_src", name = "nccl_all_gather_thunk_src",
srcs = if_nccl( srcs = if_nccl(
["nccl_all_gather_thunk.cc"], ["nccl_all_gather_thunk.cc"],
["dummy_all_gather_thunk.cc"], ["nccl_all_gather_thunk_dummy.cc"],
), ),
) )
@ -488,7 +488,7 @@ tf_cuda_library(
name = "nccl_all_gather_thunk", name = "nccl_all_gather_thunk",
srcs = if_cuda_or_rocm( srcs = if_cuda_or_rocm(
[":nccl_all_gather_thunk_src"], [":nccl_all_gather_thunk_src"],
["dummy_all_gather_thunk.cc"], ["nccl_all_gather_thunk_dummy.cc"],
), ),
hdrs = ["nccl_all_gather_thunk.h"], hdrs = ["nccl_all_gather_thunk.h"],
deps = [ deps = [
@ -520,7 +520,7 @@ filegroup(
name = "nccl_all_reduce_thunk_src", name = "nccl_all_reduce_thunk_src",
srcs = if_nccl( srcs = if_nccl(
["nccl_all_reduce_thunk.cc"], ["nccl_all_reduce_thunk.cc"],
["dummy_all_reduce_thunk.cc"], ["nccl_all_reduce_thunk_dummy.cc"],
), ),
) )
@ -528,7 +528,7 @@ tf_cuda_library(
name = "nccl_all_reduce_thunk", name = "nccl_all_reduce_thunk",
srcs = if_cuda_or_rocm( srcs = if_cuda_or_rocm(
[":nccl_all_reduce_thunk_src"], [":nccl_all_reduce_thunk_src"],
["dummy_all_reduce_thunk.cc"], ["nccl_all_reduce_thunk_dummy.cc"],
), ),
hdrs = ["nccl_all_reduce_thunk.h"], hdrs = ["nccl_all_reduce_thunk.h"],
deps = [ deps = [
@ -560,7 +560,7 @@ filegroup(
name = "nccl_all_to_all_thunk_src", name = "nccl_all_to_all_thunk_src",
srcs = if_nccl( srcs = if_nccl(
["nccl_all_to_all_thunk.cc"], ["nccl_all_to_all_thunk.cc"],
["dummy_all_to_all_thunk.cc"], ["nccl_all_to_all_thunk_dummy.cc"],
), ),
) )
@ -568,7 +568,7 @@ tf_cuda_library(
name = "nccl_all_to_all_thunk", name = "nccl_all_to_all_thunk",
srcs = if_cuda_or_rocm( srcs = if_cuda_or_rocm(
[":nccl_all_to_all_thunk_src"], [":nccl_all_to_all_thunk_src"],
["dummy_all_to_all_thunk.cc"], ["nccl_all_to_all_thunk_dummy.cc"],
), ),
hdrs = ["nccl_all_to_all_thunk.h"], hdrs = ["nccl_all_to_all_thunk.h"],
deps = [ deps = [
@ -600,7 +600,7 @@ filegroup(
name = "nccl_test_utils_src", name = "nccl_test_utils_src",
srcs = if_nccl( srcs = if_nccl(
["nccl_test_utils.cc"], ["nccl_test_utils.cc"],
["dummy_nccl_test_utils.cc"], ["nccl_test_utils_dummy.cc"],
), ),
) )
@ -608,7 +608,7 @@ tf_cuda_library(
name = "nccl_test_utils", name = "nccl_test_utils",
srcs = if_cuda_or_rocm( srcs = if_cuda_or_rocm(
[":nccl_test_utils_src"], [":nccl_test_utils_src"],
["dummy_nccl_test_utils.cc"], ["nccl_test_utils_dummy.cc"],
), ),
hdrs = ["nccl_test_utils.h"], hdrs = ["nccl_test_utils.h"],
deps = [ deps = [
@ -1452,7 +1452,11 @@ cc_library(
"//tensorflow/stream_executor:stream_executor_headers", "//tensorflow/stream_executor:stream_executor_headers",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:AsmParser",
"@llvm-project//llvm:BitReader",
"@llvm-project//llvm:BitWriter",
"@llvm-project//llvm:Core", "@llvm-project//llvm:Core",
"@llvm-project//llvm:TransformUtils",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
], ],
@ -1517,7 +1521,7 @@ cc_library(
"//tensorflow/stream_executor:stream_executor_headers", "//tensorflow/stream_executor:stream_executor_headers",
"//tensorflow/stream_executor/cuda:cuda_diagnostics", "//tensorflow/stream_executor/cuda:cuda_diagnostics",
"//tensorflow/stream_executor/gpu:asm_compiler", "//tensorflow/stream_executor/gpu:asm_compiler",
]), ]) + ["//tensorflow/stream_executor/gpu:gpu_driver_header"],
) )
cc_library( cc_library(

View File

@ -108,12 +108,17 @@ StatusOr<std::pair<std::string, std::vector<uint8>>>
AMDGPUCompiler::CompileTargetBinary(const HloModule* module, AMDGPUCompiler::CompileTargetBinary(const HloModule* module,
llvm::Module* llvm_module, llvm::Module* llvm_module,
GpuVersion gpu_version, GpuVersion gpu_version,
se::StreamExecutor* stream_exec) { se::StreamExecutor* stream_exec,
bool relocatable) {
if (rocdl_dir_.empty()) { if (rocdl_dir_.empty()) {
// Compute rocdl_dir_ just once and cache it in this member. // Compute rocdl_dir_ just once and cache it in this member.
rocdl_dir_ = GetROCDLDir(module->config()); rocdl_dir_ = GetROCDLDir(module->config());
} }
if (relocatable) {
return Unimplemented("relocatable target binary is not implemented");
}
std::vector<uint8> hsaco; std::vector<uint8> hsaco;
{ {
XLA_SCOPED_LOGGING_TIMER( XLA_SCOPED_LOGGING_TIMER(

View File

@ -41,7 +41,8 @@ class AMDGPUCompiler : public GpuCompiler {
StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary( StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
const HloModule* hlo_module, llvm::Module* llvm_module, const HloModule* hlo_module, llvm::Module* llvm_module,
GpuVersion gpu_version, se::StreamExecutor* stream_exec) override; GpuVersion gpu_version, se::StreamExecutor* stream_exec,
bool relocatable) override;
private: private:
// The parent directory of ROCm-Device-Libs IR libraries. // The parent directory of ROCm-Device-Libs IR libraries.

View File

@ -24,11 +24,15 @@ limitations under the License.
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/numbers.h" #include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/LLVMContext.h" #include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h" #include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h" #include "llvm/IR/Verifier.h"
#include "llvm/Transforms/Utils/SplitModule.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/InitAllDialects.h" // from @llvm-project #include "mlir/InitAllDialects.h" // from @llvm-project
#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/protobuf_util.h"
@ -114,11 +118,13 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/subprocess.h" #include "tensorflow/core/platform/subprocess.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/env_var.h" #include "tensorflow/core/util/env_var.h"
@ -470,14 +476,14 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment(
StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses( StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec, std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) { const CompileOptions& options) {
// We dump the post-optimization HLO in RunBackend so no need to dump it here. // We dump the post-optimization HLO in RunBackend so no need to dump it here.
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses"); XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses");
tensorflow::profiler::TraceMe activity( tensorflow::profiler::TraceMe activity(
[&] { return absl::StrCat("HLO Transforms:", module->name()); }, [&] { return absl::StrCat("HLO Transforms:", module->name()); },
tensorflow::profiler::TraceMeLevel::kInfo); tensorflow::profiler::TraceMeLevel::kInfo);
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
OptimizeHloModule(module.get(), stream_exec, device_allocator)); OptimizeHloModule(module.get(), stream_exec, options.device_allocator));
TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get())); TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get()));
@ -494,10 +500,10 @@ StatusOr<
std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>> std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
GpuCompiler::RunHloPassesAndBufferAssignement( GpuCompiler::RunHloPassesAndBufferAssignement(
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* executor, std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator, bool optimize) { bool optimize, const CompileOptions& options) {
if (optimize) { if (optimize) {
TF_ASSIGN_OR_RETURN(hlo_module, RunHloPasses(std::move(hlo_module), TF_ASSIGN_OR_RETURN(hlo_module,
executor, device_allocator)); RunHloPasses(std::move(hlo_module), executor, options));
} }
std::unique_ptr<StreamAssignment> stream_assignment = std::unique_ptr<StreamAssignment> stream_assignment =
@ -641,24 +647,133 @@ static Status CompileModuleToLlvmIrImpl(
return Status::OK(); return Status::OK();
} }
StatusOr<std::pair<std::string, std::vector<uint8>>>
GpuCompiler::CompileToTargetBinary(const HloModule& module,
std::unique_ptr<llvm::Module> llvm_module,
se::StreamExecutor* stream_exec,
const CompileOptions& options) {
using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
const auto compile_single_module =
[this, stream_exec, &module](
llvm::Module* llvm_module,
bool relocatable) -> StatusOr<BackendCompileResult> {
{
XLA_SCOPED_LOGGING_TIMER(
"GpuCompiler::RunBackend - Running LLVM verifier");
std::string err;
llvm::raw_string_ostream err_stream(err);
// verifyModule() returns true if the module is broken.
TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream))
<< "Invalid LLVM IR before optimizations:\n"
<< err_stream.str()
<< "\nThis probably indicates a bug in the HLO -> LLVM IR "
"lowering. "
"Rerun with --xla_dump_to to get the IR and looks for files "
"with "
"name containing: *"
<< FilenameFor(module, "", "") << "*";
}
GpuVersion gpu_version = GetGpuVersion(stream_exec);
return CompileTargetBinary(&module, llvm_module, gpu_version, stream_exec,
relocatable);
};
tensorflow::thread::ThreadPool* thread_pool = options.thread_pool;
if (!thread_pool) {
return compile_single_module(llvm_module.get(), /*relocatable=*/false);
}
// Test whether LinkModules is supported.
if (this->LinkModules(stream_exec, {}).status().code() ==
tensorflow::error::Code::UNIMPLEMENTED) {
return compile_single_module(llvm_module.get(), /*relocatable=*/false);
}
std::vector<std::unique_ptr<llvm::Module>> llvm_modules;
int num_functions = 0;
for (llvm::Function& func : llvm_module->functions()) {
if (!func.isDeclaration() &&
func.getLinkage() == llvm::GlobalValue::LinkageTypes::ExternalLinkage) {
num_functions++;
}
}
llvm::SplitModule(
std::move(llvm_module),
std::max<unsigned>(
1, std::min<unsigned>(thread_pool->NumThreads(), num_functions)),
[&](std::unique_ptr<llvm::Module> module) {
llvm_modules.push_back(std::move(module));
},
/*PreserveLocals=*/true);
std::vector<StatusOr<BackendCompileResult>> compile_results(
llvm_modules.size());
tensorflow::BlockingCounter counter(llvm_modules.size());
for (int i = 0; i < llvm_modules.size(); i++) {
thread_pool->Schedule([&compile_results, compile_single_module, i,
&llvm_modules, &counter] {
llvm::Module* original_module = llvm_modules[i].get();
llvm::LLVMContext context;
std::string buffer;
llvm::raw_string_ostream error(buffer);
llvm::DiagnosticPrinterRawOStream printer(error);
auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info,
void* Context) {
auto printer = static_cast<llvm::DiagnosticPrinterRawOStream*>(Context);
diag_info.print(*printer);
};
context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer);
std::unique_ptr<llvm::Module> new_llvm_module;
{
std::string ir;
{
llvm::raw_string_ostream os(ir);
original_module->print(os, nullptr);
}
llvm::SMDiagnostic err;
new_llvm_module = llvm::parseAssemblyString(ir, err, context);
}
compile_results[i] =
compile_single_module(new_llvm_module.get(), /*relocatable=*/true);
counter.DecrementCount();
});
}
counter.Wait();
std::string ptx_snippets;
std::vector<std::vector<uint8>> submodule_compile_results;
for (auto& maybe_result : compile_results) {
TF_ASSIGN_OR_RETURN(auto result, maybe_result);
if (result.second.empty()) {
continue;
}
ptx_snippets += result.first;
ptx_snippets += "\n";
submodule_compile_results.push_back(result.second);
}
TF_ASSIGN_OR_RETURN(
std::vector<uint8> backend_result,
this->LinkModules(stream_exec, std::move(submodule_compile_results)));
return std::make_pair(ptx_snippets, backend_result);
}
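
CompileToTargetBinary above first probes whether LinkModules is implemented (an UNIMPLEMENTED status falls back to single-module compilation), then splits the LLVM module into at most min(threads, external functions) pieces, compiles them concurrently on the thread pool, and links the relocatable results. A schematic Python sketch of that split/compile-in-parallel/link shape, with placeholder compile and link callables standing in for the real LLVM and driver calls:

from concurrent.futures import ThreadPoolExecutor

def compile_in_parallel(functions, num_threads, compile_one, link):
  """Split work, compile the pieces concurrently, then link the results."""
  # Mirror the std::max/std::min clamp: at least one piece, at most one
  # piece per worker thread and per function.
  num_pieces = max(1, min(num_threads, len(functions)))
  pieces = [functions[i::num_pieces] for i in range(num_pieces)]

  with ThreadPoolExecutor(max_workers=num_threads) as pool:
    results = list(pool.map(compile_one, pieces))

  # Drop empty results and link the rest, as the real code does with cubins.
  return link([r for r in results if r])

# Placeholder "compiler" and "linker" used only to exercise the control flow.
compiled = compile_in_parallel(
    functions=["f0", "f1", "f2", "f3", "f4"],
    num_threads=2,
    compile_one=lambda piece: "|".join(piece),
    link=lambda objs: "+".join(objs))
assert compiled == "f0|f2|f4+f1|f3"
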
StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend( StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec, std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) { const CompileOptions& options) {
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend"); XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend");
auto slow_compile_alarm = SlowCompilationAlarm(); auto slow_compile_alarm = SlowCompilationAlarm();
TF_RET_CHECK(stream_exec != nullptr); TF_RET_CHECK(stream_exec != nullptr);
llvm::LLVMContext llvm_context; llvm::LLVMContext llvm_context;
std::string buffer;
llvm::raw_string_ostream error(buffer);
llvm::DiagnosticPrinterRawOStream printer(error);
auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info,
void* Context) {
auto printer = static_cast<llvm::DiagnosticPrinterRawOStream*>(Context);
diag_info.print(*printer);
};
llvm_context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer);
GpuDeviceInfo gpu_device_info; GpuDeviceInfo gpu_device_info;
gpu_device_info.threads_per_block_limit = gpu_device_info.threads_per_block_limit =
@ -724,34 +839,16 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
llvm_ir::DumpIrIfEnabled(*module, *llvm_module, /*optimized=*/false); llvm_ir::DumpIrIfEnabled(*module, *llvm_module, /*optimized=*/false);
{
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier");
std::string err;
llvm::raw_string_ostream err_stream(err);
// verifyModule() returns true if the module is broken.
TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream))
<< "Invalid LLVM IR before optimizations:\n"
<< err_stream.str()
<< "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
"Rerun with --xla_dump_to to get the IR and looks for files with "
"name containing: *"
<< FilenameFor(*module, "", "") << "*";
}
GpuVersion gpu_version = GetGpuVersion(stream_exec);
using BackendCompileResult = std::pair<std::string, std::vector<uint8>>; using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result, TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
CompileTargetBinary(module.get(), llvm_module.get(), CompileToTargetBinary(*module, std::move(llvm_module),
gpu_version, stream_exec)); stream_exec, options));
if (DumpingEnabledForHloModule(*module)) { if (DumpingEnabledForHloModule(*module)) {
DumpToFileInDirOrStdout(*module, "", "thunk_schedule", DumpToFileInDirOrStdout(*module, "", "thunk_schedule",
thunk_schedule->ToString()); thunk_schedule->ToString());
} }
GpuVersion gpu_version = GetGpuVersion(stream_exec);
auto* gpu_executable = new GpuExecutable( auto* gpu_executable = new GpuExecutable(
backend_result.first, backend_result.second, gpu_version, backend_result.first, backend_result.second, gpu_version,
std::move(thunk_schedule), std::move(module), std::move(thunk_schedule), std::move(module),

View File

@ -53,14 +53,13 @@ class GpuCompiler : public LLVMCompiler {
StatusOr<std::unique_ptr<HloModule>> RunHloPasses( StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec, std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override; const CompileOptions& options) override;
StatusOr< StatusOr<
std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>> std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> hlo_module, RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> hlo_module,
se::StreamExecutor* executor, se::StreamExecutor* executor, bool optimize,
se::DeviceMemoryAllocator* device_allocator, const CompileOptions& options) override;
bool optimize) override;
Status OptimizeHloModule(HloModule* hlo_module, Status OptimizeHloModule(HloModule* hlo_module,
se::StreamExecutor* stream_exec, se::StreamExecutor* stream_exec,
@ -84,19 +83,23 @@ class GpuCompiler : public LLVMCompiler {
virtual StatusOr<std::pair<std::string, std::vector<uint8>>> virtual StatusOr<std::pair<std::string, std::vector<uint8>>>
CompileTargetBinary(const HloModule* hlo_module, llvm::Module* llvm_module, CompileTargetBinary(const HloModule* hlo_module, llvm::Module* llvm_module,
GpuVersion gpu_version, GpuVersion gpu_version, se::StreamExecutor* stream_exec,
se::StreamExecutor* stream_exec) = 0; bool relocatable) = 0;
Status PrepareHloModuleForIrEmitting(HloModule* hlo_module); Status PrepareHloModuleForIrEmitting(HloModule* hlo_module);
StatusOr<std::unique_ptr<Executable>> RunBackend( StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec, std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override; const CompileOptions& options) override;
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group, CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
AotCompilationOptions const& options) override; AotCompilationOptions const& options) override;
StatusOr<std::pair<std::string, std::vector<uint8>>> CompileToTargetBinary(
const HloModule& module, std::unique_ptr<llvm::Module> llvm_module,
se::StreamExecutor* stream_exec, const CompileOptions& options);
se::Platform::Id PlatformId() const override { return platform_id_; } se::Platform::Id PlatformId() const override { return platform_id_; }
HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override { HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
@ -116,6 +119,12 @@ class GpuCompiler : public LLVMCompiler {
} }
private: private:
virtual StatusOr<std::vector<uint8>> LinkModules(
se::StreamExecutor* stream_exec,
std::vector<std::vector<uint8>> modules) {
return Unimplemented("LinkModules is not implemented.");
}
se::Platform::Id platform_id_; se::Platform::Id platform_id_;
// The triple that represents our target. // The triple that represents our target.
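An illustrative sketch of the pattern behind the new LinkModules hook: the base class supplies a default that reports "unimplemented", and only a backend that can actually link relocatable GPU modules overrides it. Names here (ToyGpuCompiler, ToyNvptxCompiler, LinkResult) are hypothetical stand-ins, not the real XLA types.

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for StatusOr<std::vector<uint8>>.
struct LinkResult {
  bool ok;
  std::string error;
  std::vector<uint8_t> image;
};

class ToyGpuCompiler {
 public:
  virtual ~ToyGpuCompiler() = default;
  // Linking is an optional capability, so the default just reports that.
  virtual LinkResult LinkModules(std::vector<std::vector<uint8_t>> modules) {
    return {false, "LinkModules is not implemented.", {}};
  }
};

class ToyNvptxCompiler : public ToyGpuCompiler {
 public:
  // A backend that supports linking overrides the hook; here we only concatenate.
  LinkResult LinkModules(std::vector<std::vector<uint8_t>> modules) override {
    std::vector<uint8_t> linked;
    for (auto& m : modules) linked.insert(linked.end(), m.begin(), m.end());
    return {true, "", linked};
  }
};

int main() {
  ToyGpuCompiler base;
  ToyNvptxCompiler nvptx;
  std::cout << base.LinkModules({{1, 2}}).error << "\n";               // default path
  std::cout << nvptx.LinkModules({{1, 2}, {3}}).image.size() << "\n";  // 3
}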


@ -51,6 +51,7 @@ limitations under the License.
#include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
#include "tensorflow/stream_executor/gpu/asm_compiler.h" #include "tensorflow/stream_executor/gpu/asm_compiler.h"
#include "tensorflow/stream_executor/gpu/gpu_driver.h"
namespace xla { namespace xla {
namespace gpu { namespace gpu {
@ -299,7 +300,8 @@ StatusOr<std::pair<std::string, std::vector<uint8>>>
NVPTXCompiler::CompileTargetBinary(const HloModule* module, NVPTXCompiler::CompileTargetBinary(const HloModule* module,
llvm::Module* llvm_module, llvm::Module* llvm_module,
GpuVersion gpu_version, GpuVersion gpu_version,
se::StreamExecutor* stream_exec) { se::StreamExecutor* stream_exec,
bool relocatable) {
std::pair<int, int> compute_capability = std::pair<int, int> compute_capability =
absl::get<std::pair<int, int>>(gpu_version); absl::get<std::pair<int, int>>(gpu_version);
@ -338,7 +340,7 @@ NVPTXCompiler::CompileTargetBinary(const HloModule* module,
std::vector<uint8> cubin = CompileGpuAsmOrGetCachedResult( std::vector<uint8> cubin = CompileGpuAsmOrGetCachedResult(
stream_exec, ptx, compute_capability.first, compute_capability.second, stream_exec, ptx, compute_capability.first, compute_capability.second,
module->config()); module->config(), relocatable);
return std::pair<std::string, std::vector<uint8>>(std::move(ptx), return std::pair<std::string, std::vector<uint8>>(std::move(ptx),
std::move(cubin)); std::move(cubin));
@ -346,7 +348,7 @@ NVPTXCompiler::CompileTargetBinary(const HloModule* module,
std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult( std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
se::StreamExecutor* stream_exec, const string& ptx, int cc_major, se::StreamExecutor* stream_exec, const string& ptx, int cc_major,
int cc_minor, const HloModuleConfig& hlo_module_config) { int cc_minor, const HloModuleConfig& hlo_module_config, bool relocatable) {
XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompileGpuAsmOrGetCachedResult"); XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompileGpuAsmOrGetCachedResult");
tensorflow::profiler::TraceMe activity( tensorflow::profiler::TraceMe activity(
"PTX->CUBIN", tensorflow::profiler::TraceMeLevel::kInfo); "PTX->CUBIN", tensorflow::profiler::TraceMeLevel::kInfo);
@ -361,7 +363,7 @@ std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
tensorflow::mutex_lock lock(mutex_); tensorflow::mutex_lock lock(mutex_);
std::tie(iter, inserted) = compilation_cache_.emplace( std::tie(iter, inserted) = compilation_cache_.emplace(
std::piecewise_construct, std::piecewise_construct,
std::forward_as_tuple(ptx, cc_major, cc_minor), std::forward_as_tuple(ptx, cc_major, cc_minor, relocatable),
std::forward_as_tuple()); std::forward_as_tuple());
cache_ptx = &iter->first.ptx; cache_ptx = &iter->first.ptx;
cache_value = &iter->second; cache_value = &iter->second;
@ -375,9 +377,13 @@ std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
if (inserted) { if (inserted) {
CHECK(!cache_value->compilation_done); CHECK(!cache_value->compilation_done);
if (!ptx.empty()) { if (!ptx.empty()) {
StatusOr<std::vector<uint8>> maybe_cubin = auto ptxas_config = PtxOptsFromConfig(hlo_module_config);
se::CompileGpuAsm(stream_exec->device_ordinal(), cache_ptx->c_str(), if (relocatable) {
PtxOptsFromConfig(hlo_module_config)); ptxas_config.extra_flags.push_back("-c");
}
StatusOr<std::vector<uint8>> maybe_cubin = se::CompileGpuAsm(
stream_exec->device_ordinal(), cache_ptx->c_str(), ptxas_config);
if (maybe_cubin.ok()) { if (maybe_cubin.ok()) {
cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie(); cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie();
VLOG(2) << "Compiled PTX size:" << ptx.size() VLOG(2) << "Compiled PTX size:" << ptx.size()
@ -445,5 +451,17 @@ std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
return cache_value->cubin_data; return cache_value->cubin_data;
} }
StatusOr<std::vector<uint8>> NVPTXCompiler::LinkModules(
se::StreamExecutor* stream_exec, std::vector<std::vector<uint8>> modules) {
std::vector<stream_executor::CubinOrPTXImage> images;
images.reserve(modules.size());
for (auto& module : modules) {
images.push_back({"", std::move(module)});
}
return LinkGpuAsm(static_cast<se::gpu::GpuContext*>(
stream_exec->implementation()->GpuContextHack()),
images);
}
} // namespace gpu } // namespace gpu
} // namespace xla } // namespace xla
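A rough analogy for the relocatable path introduced above (hypothetical helper, not the real ptxas or LinkGpuAsm interface): an extra "-c"-style flag asks for a relocatable object per module, and a separate link step later combines the resulting images.

#include <iostream>
#include <string>
#include <vector>

// Hypothetical flag builder: "-c" switches from a final image to a relocatable object.
std::vector<std::string> BuildAsmFlags(bool relocatable) {
  std::vector<std::string> flags = {"-O3"};
  if (relocatable) {
    flags.push_back("-c");  // compile only; a later link step combines the objects
  }
  return flags;
}

int main() {
  for (bool relocatable : {false, true}) {
    std::cout << (relocatable ? "relocatable:   " : "whole program: ");
    for (const auto& f : BuildAsmFlags(relocatable)) std::cout << f << ' ';
    std::cout << '\n';
  }
}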


@ -52,9 +52,14 @@ class NVPTXCompiler : public GpuCompiler {
StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary( StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
const HloModule* hlo_module, llvm::Module* llvm_module, const HloModule* hlo_module, llvm::Module* llvm_module,
GpuVersion gpu_version, se::StreamExecutor* stream_exec) override; GpuVersion gpu_version, se::StreamExecutor* stream_exec,
bool relocatable) override;
private: private:
StatusOr<std::vector<uint8>> LinkModules(
se::StreamExecutor* stream_exec,
std::vector<std::vector<uint8>> modules) override;
tensorflow::mutex mutex_; tensorflow::mutex mutex_;
// When compiling an HLO module, we need to find a path to the nvvm libdevice // When compiling an HLO module, we need to find a path to the nvvm libdevice
@ -71,7 +76,7 @@ class NVPTXCompiler : public GpuCompiler {
// compiled cubin. If compilation was unsuccessful, returns an empty vector. // compiled cubin. If compilation was unsuccessful, returns an empty vector.
std::vector<uint8> CompileGpuAsmOrGetCachedResult( std::vector<uint8> CompileGpuAsmOrGetCachedResult(
se::StreamExecutor* stream_exec, const string& ptx, int cc_major, se::StreamExecutor* stream_exec, const string& ptx, int cc_major,
int cc_minor, const HloModuleConfig& hlo_module_config); int cc_minor, const HloModuleConfig& hlo_module_config, bool relocatable);
// The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor} // The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor}
// -> cubin so we don't recompile the same ptx twice. This is important for // -> cubin so we don't recompile the same ptx twice. This is important for
@ -86,24 +91,32 @@ class NVPTXCompiler : public GpuCompiler {
// If compiling the ptx fails, we return an empty cubin, cross our fingers, // If compiling the ptx fails, we return an empty cubin, cross our fingers,
// and leave compilation up to the driver. // and leave compilation up to the driver.
struct CompilationCacheKey { struct CompilationCacheKey {
CompilationCacheKey(std::string ptx, int cc_major, int cc_minor) CompilationCacheKey(std::string ptx, int cc_major, int cc_minor,
: ptx(std::move(ptx)), cc_major(cc_major), cc_minor(cc_minor) {} bool relocatable)
: ptx(std::move(ptx)),
cc_major(cc_major),
cc_minor(cc_minor),
relocatable(relocatable) {}
string ptx; string ptx;
int cc_major; int cc_major;
int cc_minor; int cc_minor;
bool relocatable;
}; };
struct CompilationCacheHash { struct CompilationCacheHash {
size_t operator()(const CompilationCacheKey& key) const { size_t operator()(const CompilationCacheKey& key) const {
return tensorflow::Hash64Combine( return tensorflow::Hash64Combine(
tensorflow::Hash64Combine(tensorflow::Hash64(key.ptx), key.cc_major), tensorflow::Hash64Combine(
key.cc_minor); tensorflow::Hash64Combine(tensorflow::Hash64(key.ptx),
key.cc_major),
key.cc_minor),
key.relocatable);
} }
}; };
struct CompilationCacheEq { struct CompilationCacheEq {
size_t operator()(const CompilationCacheKey& a, size_t operator()(const CompilationCacheKey& a,
const CompilationCacheKey& b) const { const CompilationCacheKey& b) const {
return a.cc_major == b.cc_major && a.cc_minor == b.cc_minor && return a.cc_major == b.cc_major && a.cc_minor == b.cc_minor &&
a.ptx == b.ptx; a.ptx == b.ptx && a.relocatable == b.relocatable;
} }
}; };
struct CompilationCacheValue { struct CompilationCacheValue {
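A compact, self-contained sketch of the cache-key idea shown above, using std::hash in place of tensorflow::Hash64Combine and hypothetical names: every field that influences the compiled artifact, including the new relocatable bit, has to take part in both the hash and the equality check, otherwise a relocatable and a non-relocatable compile of the same PTX would collide in the cache.

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical cache key: all fields below affect the generated binary.
struct CacheKey {
  std::string ptx;
  int cc_major;
  int cc_minor;
  bool relocatable;

  bool operator==(const CacheKey& o) const {
    return cc_major == o.cc_major && cc_minor == o.cc_minor &&
           relocatable == o.relocatable && ptx == o.ptx;
  }
};

struct CacheKeyHash {
  size_t operator()(const CacheKey& k) const {
    auto combine = [](size_t a, size_t b) {
      return a ^ (b + 0x9e3779b97f4a7c15ULL + (a << 6) + (a >> 2));
    };
    size_t h = std::hash<std::string>()(k.ptx);
    h = combine(h, std::hash<int>()(k.cc_major));
    h = combine(h, std::hash<int>()(k.cc_minor));
    h = combine(h, std::hash<bool>()(k.relocatable));
    return h;
  }
};

int main() {
  std::unordered_map<CacheKey, std::string, CacheKeyHash> cache;
  cache[{".version 7.0", 7, 5, false}] = "whole-program cubin";
  cache[{".version 7.0", 7, 5, true}] = "relocatable cubin";
  std::cout << cache.size() << " distinct entries\n";  // prints: 2 distinct entries
}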


@ -95,7 +95,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses( StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* /*stream_exec*/, std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* /*stream_exec*/,
se::DeviceMemoryAllocator* /*device_allocator*/) { const CompileOptions& /*options*/) {
VLOG(1) << "Run hlo passes on graph " << hlo_module->name(); VLOG(1) << "Run hlo passes on graph " << hlo_module->name();
TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get())); TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get()));
return std::move(hlo_module); return std::move(hlo_module);
@ -103,7 +103,7 @@ StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend( StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec, std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* /*device_allocator*/) { const CompileOptions& /*options*/) {
TF_RET_CHECK(stream_exec != nullptr); TF_RET_CHECK(stream_exec != nullptr);
VLOG(1) << "Run backend " << hlo_module->name(); VLOG(1) << "Run backend " << hlo_module->name();
@ -128,7 +128,7 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile( StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
std::unique_ptr<HloModuleGroup> module_group, std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_exec, std::vector<std::vector<se::StreamExecutor*>> stream_exec,
se::DeviceMemoryAllocator* device_allocator) { const CompileOptions& options) {
if (module_group->empty()) { if (module_group->empty()) {
return std::vector<std::unique_ptr<Executable>>(); return std::vector<std::unique_ptr<Executable>>();
} }
@ -141,12 +141,10 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
"Unexpected number of StreamExecutor's."); "Unexpected number of StreamExecutor's.");
} }
auto hlo_modules = module_group->ConsumeModules(); auto hlo_modules = module_group->ConsumeModules();
TF_ASSIGN_OR_RETURN(auto module, TF_ASSIGN_OR_RETURN(auto module, RunHloPasses(std::move(hlo_modules[0]),
RunHloPasses(std::move(hlo_modules[0]), stream_exec[0][0], stream_exec[0][0], options));
device_allocator)); TF_ASSIGN_OR_RETURN(auto executable, RunBackend(std::move(module),
TF_ASSIGN_OR_RETURN( stream_exec[0][0], options));
auto executable,
RunBackend(std::move(module), stream_exec[0][0], device_allocator));
std::vector<std::unique_ptr<Executable>> ret; std::vector<std::unique_ptr<Executable>> ret;
ret.push_back(std::move(executable)); ret.push_back(std::move(executable));
return std::move(ret); return std::move(ret);


@ -45,14 +45,14 @@ class InterpreterCompiler : public Compiler {
StatusOr<std::unique_ptr<HloModule>> RunHloPasses( StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec, std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override; const CompileOptions& options) override;
StatusOr<std::unique_ptr<Executable>> RunBackend( StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec, std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override; const CompileOptions& options) override;
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile( StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::unique_ptr<HloModuleGroup> module_group, std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_exec, std::vector<std::vector<se::StreamExecutor*>> stream_exec,
se::DeviceMemoryAllocator* device_allocator) override; const CompileOptions& options) override;
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group, CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,


@ -24,7 +24,7 @@ namespace xla {
StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile( StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(
std::unique_ptr<HloModuleGroup> module_group, std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_execs, std::vector<std::vector<se::StreamExecutor*>> stream_execs,
se::DeviceMemoryAllocator* device_allocator) { const CompileOptions& options) {
// Tensorflow tries to enable the following behaviors in all its threads: // Tensorflow tries to enable the following behaviors in all its threads:
// //
// - Denormals are zero (DAZ): roughly, operations treat denormal floats as // - Denormals are zero (DAZ): roughly, operations treat denormal floats as
@ -48,10 +48,10 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(
TF_ASSIGN_OR_RETURN(modules[i], TF_ASSIGN_OR_RETURN(modules[i],
RunHloPasses(std::move(modules[i]), stream_execs[i][0], RunHloPasses(std::move(modules[i]), stream_execs[i][0],
device_allocator)); options.device_allocator));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable, TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
RunBackend(std::move(modules[i]), stream_execs[i][0], RunBackend(std::move(modules[i]), stream_execs[i][0],
device_allocator)); options.device_allocator));
result.push_back(std::move(executable)); result.push_back(std::move(executable));
} }


@ -66,13 +66,14 @@ class LLVMCompiler : public Compiler {
// std::unique_ptr<HloModule> module, // std::unique_ptr<HloModule> module,
// se::StreamExecutor* stream_exec, // se::StreamExecutor* stream_exec,
// se::DeviceMemoryAllocator* device_allocator) // se::DeviceMemoryAllocator* device_allocator)
using Compiler::Compile;
using Compiler::RunBackend; using Compiler::RunBackend;
using Compiler::RunHloPasses; using Compiler::RunHloPasses;
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile( StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::unique_ptr<HloModuleGroup> module_group, std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_execs, std::vector<std::vector<se::StreamExecutor*>> stream_execs,
se::DeviceMemoryAllocator* device_allocator) override; const CompileOptions& options) override;
protected: protected:
ModuleHook user_pre_optimization_hook_; ModuleHook user_pre_optimization_hook_;
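The added "using Compiler::Compile;" matters because of C++ name hiding: a Compile declaration in the derived class would otherwise hide every base-class overload with the same name, not just the one it overrides. A tiny standalone illustration of the rule, with made-up class names:

#include <iostream>
#include <string>

struct Base {
  void Run(int x) { std::cout << "Base::Run(int) " << x << "\n"; }
  void Run(const std::string& s) { std::cout << "Base::Run(string) " << s << "\n"; }
};

struct Derived : Base {
  using Base::Run;  // without this, the overload below hides BOTH base overloads
  void Run(double d) { std::cout << "Derived::Run(double) " << d << "\n"; }
};

int main() {
  Derived d;
  d.Run(42);       // Base::Run(int); without the using-declaration this would pick Run(double)
  d.Run("hello");  // Base::Run(string); without the using-declaration this would not compile
  d.Run(3.14);     // Derived::Run(double)
}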


@ -190,11 +190,12 @@ LocalService::CompileExecutables(
// single partition computations are built using `BuildExecutables`, fix it, // single partition computations are built using `BuildExecutables`, fix it,
// and remove this special case (provided the performance if similar). // and remove this special case (provided the performance if similar).
if (build_options.num_partitions() == 1) { if (build_options.num_partitions() == 1) {
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
std::unique_ptr<Executable> executable, BuildExecutable(proto, std::move(module_config),
BuildExecutable(proto, std::move(module_config), execute_backend_.get(), execute_backend_.get(), executor,
executor, build_options.device_allocator(), {build_options.device_allocator(),
build_options.run_backend_only())); build_options.compile_thread_pool()},
build_options.run_backend_only()));
std::vector<std::unique_ptr<Executable>> executables; std::vector<std::unique_ptr<Executable>> executables;
executables.push_back(std::move(executable)); executables.push_back(std::move(executable));
return executables; return executables;
@ -206,10 +207,12 @@ LocalService::CompileExecutables(
std::vector<se::StreamExecutor*> executors(build_options.num_partitions(), std::vector<se::StreamExecutor*> executors(build_options.num_partitions(),
executor); executor);
return BuildExecutables({&proto}, std::move(module_configs), return BuildExecutables(
execute_backend_.get(), {executors}, /*module_protos=*/{&proto}, std::move(module_configs),
build_options.device_allocator(), execute_backend_.get(), {executors},
build_options.run_backend_only()); Compiler::CompileOptions{build_options.device_allocator(),
build_options.compile_thread_pool()},
build_options.run_backend_only());
} }
} }
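A stripped-down sketch of why the call sites above can simply pass {build_options.device_allocator(), build_options.compile_thread_pool()}: bundling the per-compile knobs into one options struct lets fields be added later without touching every signature, and brace initialization fills only the fields a caller cares about. The types below are placeholders, not the actual XLA declarations.

#include <iostream>

// Placeholder stand-ins for the real allocator / thread-pool types.
struct DeviceMemoryAllocator {};
struct ThreadPool {};

// One options struct instead of a growing list of positional parameters.
struct CompileOptions {
  DeviceMemoryAllocator* device_allocator = nullptr;
  ThreadPool* thread_pool = nullptr;
};

void RunBackend(const CompileOptions& options) {
  std::cout << "allocator set: " << (options.device_allocator != nullptr)
            << ", thread pool set: " << (options.thread_pool != nullptr) << "\n";
}

int main() {
  DeviceMemoryAllocator allocator;
  ThreadPool pool;
  RunBackend({&allocator, &pool});  // both knobs supplied
  RunBackend({&allocator});         // thread pool left at its default
  RunBackend({});                   // equivalent of the old "nullptr allocator" call sites
}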


@ -28,28 +28,24 @@ bool IsUnimplemented(StatusOr<T>& result) {
StatusOr<std::unique_ptr<HloModule>> FailoverCompiler::RunHloPasses( StatusOr<std::unique_ptr<HloModule>> FailoverCompiler::RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec, std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) { const CompileOptions& options) {
auto result = auto result = primary_->RunHloPasses(module->Clone(), stream_exec, options);
primary_->RunHloPasses(module->Clone(), stream_exec, device_allocator);
if (IsUnimplemented(result)) { if (IsUnimplemented(result)) {
VLOG(2) << "RunHloPasses resulted in " << result.status() VLOG(2) << "RunHloPasses resulted in " << result.status()
<< ", falling back to secondary backend"; << ", falling back to secondary backend";
return secondary_->RunHloPasses(std::move(module), stream_exec, return secondary_->RunHloPasses(std::move(module), stream_exec, options);
device_allocator);
} }
return result; return result;
} }
StatusOr<std::unique_ptr<Executable>> FailoverCompiler::RunBackend( StatusOr<std::unique_ptr<Executable>> FailoverCompiler::RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec, std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) { const CompileOptions& options) {
auto result = auto result = primary_->RunBackend(module->Clone(), stream_exec, options);
primary_->RunBackend(module->Clone(), stream_exec, device_allocator);
if (IsUnimplemented(result)) { if (IsUnimplemented(result)) {
VLOG(2) << "RunBackend resulted in " << result.status() VLOG(2) << "RunBackend resulted in " << result.status()
<< ", falling back to secondary backend"; << ", falling back to secondary backend";
return secondary_->RunBackend(std::move(module), stream_exec, return secondary_->RunBackend(std::move(module), stream_exec, options);
device_allocator);
} }
return result; return result;
} }
@ -57,7 +53,7 @@ StatusOr<std::unique_ptr<Executable>> FailoverCompiler::RunBackend(
StatusOr<std::vector<std::unique_ptr<Executable>>> FailoverCompiler::Compile( StatusOr<std::vector<std::unique_ptr<Executable>>> FailoverCompiler::Compile(
std::unique_ptr<HloModuleGroup> module_group, std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_execs, std::vector<std::vector<se::StreamExecutor*>> stream_execs,
se::DeviceMemoryAllocator* device_allocator) { const CompileOptions& options) {
std::vector<std::unique_ptr<Executable>> result; std::vector<std::unique_ptr<Executable>> result;
std::vector<std::unique_ptr<HloModule>> modules = std::vector<std::unique_ptr<HloModule>> modules =
module_group->ConsumeModules(); module_group->ConsumeModules();
@ -67,17 +63,15 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> FailoverCompiler::Compile(
return Unimplemented( return Unimplemented(
"Model partitioning not implemented for the failover compiler!"); "Model partitioning not implemented for the failover compiler!");
} }
auto executable = [stream_execs, device_allocator, i, auto executable = [stream_execs, &options, i,
this](std::unique_ptr<HloModule> module) this](std::unique_ptr<HloModule> module)
-> StatusOr<std::unique_ptr<Executable>> { -> StatusOr<std::unique_ptr<Executable>> {
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(auto processed_module,
auto processed_module, primary_->RunHloPasses(std::move(module),
primary_->RunHloPasses(std::move(module), stream_execs[i][0], stream_execs[i][0], options));
device_allocator)); TF_ASSIGN_OR_RETURN(auto result,
TF_ASSIGN_OR_RETURN( primary_->RunBackend(std::move(processed_module),
auto result, stream_execs[i][0], options));
primary_->RunBackend(std::move(processed_module), stream_execs[i][0],
device_allocator));
return result; return result;
}(modules[i]->Clone()); }(modules[i]->Clone());
@ -85,13 +79,11 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> FailoverCompiler::Compile(
VLOG(2) << "Compile resulted in " << executable.status() VLOG(2) << "Compile resulted in " << executable.status()
<< ", falling back to secondary backend"; << ", falling back to secondary backend";
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
modules[i], modules[i], secondary_->RunHloPasses(std::move(modules[i]),
secondary_->RunHloPasses(std::move(modules[i]), stream_execs[i][0], stream_execs[i][0], options));
device_allocator)); TF_ASSIGN_OR_RETURN(executable,
TF_ASSIGN_OR_RETURN( secondary_->RunBackend(std::move(modules[i]),
executable, stream_execs[i][0], options));
secondary_->RunBackend(std::move(modules[i]), stream_execs[i][0],
device_allocator));
} }
if (!executable.ok()) { if (!executable.ok()) {


@ -51,16 +51,16 @@ class FailoverCompiler final : public Compiler {
StatusOr<std::unique_ptr<HloModule>> RunHloPasses( StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec, std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override; const CompileOptions& options) override;
StatusOr<std::unique_ptr<Executable>> RunBackend( StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec, std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override; const CompileOptions& options) override;
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile( StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::unique_ptr<HloModuleGroup> module_group, std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_execs, std::vector<std::vector<se::StreamExecutor*>> stream_execs,
se::DeviceMemoryAllocator* device_allocator) override; const CompileOptions& options) override;
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group, CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,


@ -87,16 +87,16 @@ class MlirCompilerImpl : public MlirCompiler {
public: public:
StatusOr<std::unique_ptr<HloModule>> RunHloPasses( StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec, std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override; const CompileOptions& options) override;
StatusOr<std::unique_ptr<Executable>> RunBackend( StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec, std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override; const CompileOptions& options) override;
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile( StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::unique_ptr<HloModuleGroup> module_group, std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_execs, std::vector<std::vector<se::StreamExecutor*>> stream_execs,
se::DeviceMemoryAllocator* device_allocator) override; const CompileOptions& options) override;
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group, CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
@ -155,12 +155,12 @@ std::string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) {
StatusOr<std::unique_ptr<HloModule>> MlirCompilerImpl::RunHloPasses( StatusOr<std::unique_ptr<HloModule>> MlirCompilerImpl::RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec, std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) { const CompileOptions& options) {
// Until we find a reason to do something different, run the same passes // Until we find a reason to do something different, run the same passes
// that the normal GPU backend runs. // that the normal GPU backend runs.
gpu::NVPTXCompiler xla_compiler; gpu::NVPTXCompiler xla_compiler;
TF_RETURN_IF_ERROR(xla_compiler.OptimizeHloModule(module.get(), stream_exec, TF_RETURN_IF_ERROR(xla_compiler.OptimizeHloModule(module.get(), stream_exec,
device_allocator)); options.device_allocator));
TF_RETURN_IF_ERROR(xla_compiler.PrepareHloModuleForIrEmitting(module.get())); TF_RETURN_IF_ERROR(xla_compiler.PrepareHloModuleForIrEmitting(module.get()));
return std::move(module); return std::move(module);
@ -454,7 +454,7 @@ StatusOr<std::unique_ptr<gpu::KernelThunk>> TransformKernelToXlaThunk(
StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend( StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec, std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) { const CompileOptions& options) {
// Determine the HLO schedule, which is an ordering of HLO instructions. This // Determine the HLO schedule, which is an ordering of HLO instructions. This
// is used by buffer assignment to enable buffer reuse, and the same ordering // is used by buffer assignment to enable buffer reuse, and the same ordering
// must also be used to determine the thunk launch schedule. // must also be used to determine the thunk launch schedule.
@ -595,7 +595,7 @@ StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
StatusOr<std::vector<std::unique_ptr<Executable>>> MlirCompilerImpl::Compile( StatusOr<std::vector<std::unique_ptr<Executable>>> MlirCompilerImpl::Compile(
std::unique_ptr<HloModuleGroup> module_group, std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_execs, std::vector<std::vector<se::StreamExecutor*>> stream_execs,
se::DeviceMemoryAllocator* device_allocator) { const CompileOptions& options) {
return Unimplemented("Not yet implemented in MLIR compiler"); return Unimplemented("Not yet implemented in MLIR compiler");
} }


@ -357,7 +357,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
const std::vector<const HloModuleProto*>& module_protos, const std::vector<const HloModuleProto*>& module_protos,
std::vector<std::unique_ptr<HloModuleConfig>> module_configs, std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors, Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
se::DeviceMemoryAllocator* device_allocator, bool run_backend_only) { const Compiler::CompileOptions& options, bool run_backend_only) {
VLOG(1) << StrFormat("BuildExecutable on service %p", this); VLOG(1) << StrFormat("BuildExecutable on service %p", this);
// Dump computation proto state if flag is set. // Dump computation proto state if flag is set.
@ -387,17 +387,15 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
std::vector<std::unique_ptr<Executable>> executables; std::vector<std::unique_ptr<Executable>> executables;
if (!run_backend_only) { if (!run_backend_only) {
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(executables, backend->compiler()->Compile(
executables, std::move(module_group),
backend->compiler()->Compile(std::move(module_group), std::move(executors), options));
std::move(executors), device_allocator));
} else { } else {
auto modules = module_group->ConsumeModules(); auto modules = module_group->ConsumeModules();
for (std::unique_ptr<HloModule>& module : modules) { for (std::unique_ptr<HloModule>& module : modules) {
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
std::unique_ptr<Executable> executable, backend->compiler()->RunBackend(
backend->compiler()->RunBackend(std::move(module), executors[0][0], std::move(module), executors[0][0], options));
device_allocator));
executables.push_back(std::move(executable)); executables.push_back(std::move(executable));
} }
} }
@ -710,7 +708,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables, TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
BuildExecutables(module_protos, std::move(module_configs), BuildExecutables(module_protos, std::move(module_configs),
execute_backend_.get(), all_executors, execute_backend_.get(), all_executors,
/*device_allocator=*/nullptr)); {/*device_allocator=*/nullptr}));
std::vector<Executable*> executable_ptrs; std::vector<Executable*> executable_ptrs;
executable_ptrs.reserve(executables.size()); executable_ptrs.reserve(executables.size());
for (const auto& executable : executables) { for (const auto& executable : executables) {
@ -810,7 +808,7 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable( StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
const HloModuleProto& module_proto, const HloModuleProto& module_proto,
std::unique_ptr<HloModuleConfig> module_config, Backend* backend, std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator, se::StreamExecutor* executor, const Compiler::CompileOptions& options,
bool run_backend_only) { bool run_backend_only) {
VLOG(1) << StrFormat( VLOG(1) << StrFormat(
"BuildExecutable on service %p with serialized module proto: %s", this, "BuildExecutable on service %p with serialized module proto: %s", this,
@ -822,14 +820,13 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName); DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
if (!run_backend_only) { if (!run_backend_only) {
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(module, backend->compiler()->RunHloPasses(
module, backend->compiler()->RunHloPasses(std::move(module), executor, std::move(module), executor, options));
device_allocator));
} }
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable, TF_ASSIGN_OR_RETURN(
backend->compiler()->RunBackend( std::unique_ptr<Executable> executable,
std::move(module), executor, device_allocator)); backend->compiler()->RunBackend(std::move(module), executor, options));
const auto& debug_opts = module_config->debug_options(); const auto& debug_opts = module_config->debug_options();
if (DumpingEnabledForHloModule(module_proto.name(), debug_opts) && if (DumpingEnabledForHloModule(module_proto.name(), debug_opts) &&
@ -875,7 +872,7 @@ Status Service::Compile(const CompileRequest* arg, CompileResponse* result) {
BuildExecutable(arg->computation(), std::move(module_config), BuildExecutable(arg->computation(), std::move(module_config),
execute_backend_.get(), execute_backend_.get(),
execute_backend_->default_stream_executor(), execute_backend_->default_stream_executor(),
/*device_allocator=*/nullptr)); {/*device_allocator=*/nullptr}));
*result->mutable_handle() = compilation_cache_.Insert(std::move(executable)); *result->mutable_handle() = compilation_cache_.Insert(std::move(executable));


@ -235,8 +235,7 @@ class Service : public ServiceInterface {
StatusOr<std::unique_ptr<Executable>> BuildExecutable( StatusOr<std::unique_ptr<Executable>> BuildExecutable(
const HloModuleProto& module_proto, const HloModuleProto& module_proto,
std::unique_ptr<HloModuleConfig> module_config, Backend* backend, std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
se::StreamExecutor* executor, se::StreamExecutor* executor, const Compiler::CompileOptions& options,
se::DeviceMemoryAllocator* device_allocator = nullptr,
bool run_backend_only = false); bool run_backend_only = false);
// Same as BuildExecutable() above, but builds a list of Executables for the // Same as BuildExecutable() above, but builds a list of Executables for the
@ -245,8 +244,7 @@ class Service : public ServiceInterface {
const std::vector<const HloModuleProto*>& module_protos, const std::vector<const HloModuleProto*>& module_protos,
std::vector<std::unique_ptr<HloModuleConfig>> module_configs, std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors, Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
se::DeviceMemoryAllocator* device_allocator, const Compiler::CompileOptions& options, bool run_backend_only = false);
bool run_backend_only = false);
// Runs the given executable with the given arguments and register the result // Runs the given executable with the given arguments and register the result
// in the allocation tracker. The handle of the result from the tracker is // in the allocation tracker. The handle of the result from the tracker is


@ -102,27 +102,8 @@ bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) {
} }
} }
// Returns a sharding where each tuple element is chosen as the more specific
// one of the corresponding elements in a and b. Requires a an b to have the
// same tuple nesting.
HloSharding MergeForMoreSpecificSharding(const HloSharding& a,
const HloSharding& b) {
if (a.IsTuple()) {
HloSharding result = a;
CHECK(b.IsTuple());
CHECK_EQ(a.tuple_elements().size(), b.tuple_elements().size());
for (int64 i = 0; i < result.tuple_elements().size(); ++i) {
result.tuple_elements()[i] = MergeForMoreSpecificSharding(
a.tuple_elements()[i], b.tuple_elements()[i]);
}
return result;
}
return IsShardingMoreSpecific(a, b) ? a : b;
}
// Tries to refine `to_merge` by combining with `old`. Returns if the final // Tries to refine `to_merge` by combining with `old`. Returns if the final
// `to_merge` is more specific than `old`. May combine partial sharding in // `to_merge` is more specific than `old`.
// addition to MergeForMoreSpecificSharding().
bool MergeSharding(const HloSharding& old, HloSharding* to_merge, bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
bool may_combine_partial_sharding) { bool may_combine_partial_sharding) {
if (old.IsTuple()) { if (old.IsTuple()) {
@ -1093,8 +1074,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
} }
auto sharding = instruction->operand(0)->sharding(); auto sharding = instruction->operand(0)->sharding();
if (instruction->has_sharding()) { if (instruction->has_sharding()) {
sharding = MergeSharding(instruction->sharding(), &sharding,
MergeForMoreSpecificSharding(sharding, instruction->sharding()); may_combine_partial_sharding);
} }
return MaybeImproveInstructionSharding(std::move(sharding), instruction, return MaybeImproveInstructionSharding(std::move(sharding), instruction,
may_combine_partial_sharding); may_combine_partial_sharding);
@ -1320,6 +1301,12 @@ absl::optional<HloSharding> GetShardingFromUser(
return hlo_sharding_util::ReshapeSharding( return hlo_sharding_util::ReshapeSharding(
user.shape(), instruction.shape(), user.sharding()); user.shape(), instruction.shape(), user.sharding());
} }
case HloOpcode::kPad: {
if (&instruction != user.operand(0)) {
return absl::nullopt;
}
return user.sharding();
}
case HloOpcode::kSlice: { case HloOpcode::kSlice: {
return user.sharding(); return user.sharding();
} }
@ -1673,8 +1660,10 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
// If instruction is a while, or the root or a parameter of a while body, // If instruction is a while, or the root or a parameter of a while body,
// then propagate its sharding to the while instruction, to its body root, // then propagate its sharding to the while instruction, to its body root,
// and to its condition parameter. // and to its condition parameter.
std::function<void(HloInstruction*)> maybe_computation_propagation = std::function<void(HloInstruction*, absl::flat_hash_set<HloInstruction*>*)>
[&](HloInstruction* instruction) { maybe_computation_propagation = [&](HloInstruction* instruction,
absl::flat_hash_set<HloInstruction*>*
changed) {
auto propagate_to_instruction = [&](HloInstruction* search_inst) { auto propagate_to_instruction = [&](HloInstruction* search_inst) {
auto related_instructions = get_related_instructions(search_inst); auto related_instructions = get_related_instructions(search_inst);
if (absl::c_count(related_instructions, instruction)) { if (absl::c_count(related_instructions, instruction)) {
@ -1683,7 +1672,8 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
inst->sharding() != instruction->sharding()) { inst->sharding() != instruction->sharding()) {
VLOG(2) << "Add computation sharding: " << inst->name(); VLOG(2) << "Add computation sharding: " << inst->name();
inst->set_sharding(instruction->sharding()); inst->set_sharding(instruction->sharding());
maybe_computation_propagation(inst); changed->insert(inst);
maybe_computation_propagation(inst, changed);
} }
} }
} }
@ -1785,6 +1775,14 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
for (const HloInstruction* instruction : instructions) { for (const HloInstruction* instruction : instructions) {
already_sharded_counter += (instruction->has_sharding() ? 1 : 0); already_sharded_counter += (instruction->has_sharding() ? 1 : 0);
} }
auto clear_cache = [&](HloInstruction* hlo) {
for (auto operand : hlo->operands()) {
already_inferred_from_users.erase(operand);
}
for (auto user : hlo->users()) {
already_inferred_from_operands.erase(user);
}
};
// First iterate the HLO graph in post order taking shardings from // First iterate the HLO graph in post order taking shardings from
// operands. // operands.
for (HloInstruction* instruction : instructions) { for (HloInstruction* instruction : instructions) {
@ -1799,12 +1797,11 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
any_changed = true; any_changed = true;
VLOG(2) << "Add sharding (forward-pass): " VLOG(2) << "Add sharding (forward-pass): "
<< instruction->ToString(); << instruction->ToString();
maybe_computation_propagation(instruction); absl::flat_hash_set<HloInstruction*> changed_in_comp_prop;
for (auto operand : instruction->operands()) { maybe_computation_propagation(instruction, &changed_in_comp_prop);
already_inferred_from_users.erase(operand); clear_cache(instruction);
} for (auto hlo : changed_in_comp_prop) {
for (auto user : instruction->users()) { clear_cache(hlo);
already_inferred_from_operands.erase(user);
} }
changed_last_iter = true; changed_last_iter = true;
} }
@ -1823,12 +1820,11 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
++inferred_from_user_counter; ++inferred_from_user_counter;
any_changed = true; any_changed = true;
VLOG(2) << "Add sharding (backward-pass): " << (*it)->ToString(); VLOG(2) << "Add sharding (backward-pass): " << (*it)->ToString();
maybe_computation_propagation(*it); absl::flat_hash_set<HloInstruction*> changed_in_comp_prop;
for (auto operand : (*it)->operands()) { maybe_computation_propagation(*it, &changed_in_comp_prop);
already_inferred_from_users.erase(operand); clear_cache(*it);
} for (auto hlo : changed_in_comp_prop) {
for (auto user : (*it)->users()) { clear_cache(hlo);
already_inferred_from_operands.erase(user);
} }
changed_last_iter = true; changed_last_iter = true;
} }
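A compact sketch of the invalidation idea behind the clear_cache lambda and the changed set introduced above (a toy graph type, not HloInstruction): when a node's sharding changes, the cached "already inferred" marks of its neighbours are dropped so the next sweep revisits them, and every node touched during computation propagation gets the same treatment.

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Toy node in a tiny dataflow graph.
struct Node {
  std::string name;
  std::vector<Node*> operands;
  std::vector<Node*> users;
};

int main() {
  Node a{"a"}, b{"b"}, c{"c"};
  b.operands = {&a}; a.users = {&b};
  c.operands = {&b}; b.users = {&c};

  std::unordered_set<Node*> already_inferred_from_operands = {&b, &c};
  std::unordered_set<Node*> already_inferred_from_users = {&a, &b};

  // Mirrors the shape of the clear_cache lambda: forget cached results around `hlo`.
  auto clear_cache = [&](Node* hlo) {
    for (Node* operand : hlo->operands) already_inferred_from_users.erase(operand);
    for (Node* user : hlo->users) already_inferred_from_operands.erase(user);
  };

  // Suppose b's sharding just changed: its neighbours must be revisited.
  clear_cache(&b);
  std::cout << "a still cached for the users pass: "
            << already_inferred_from_users.count(&a) << "\n";     // 0
  std::cout << "c still cached for the operands pass: "
            << already_inferred_from_operands.count(&c) << "\n";  // 0
}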


@ -514,6 +514,26 @@ ENTRY %pad {
op::Sharding("{devices=[2,2]0,1,2,3}")); op::Sharding("{devices=[2,2]0,1,2,3}"));
} }
TEST_F(ShardingPropagationTest, PadBackwardPass) {
const char* const hlo_string = R"(
HloModule module
ENTRY %pad {
%input = f32[11,17]{1,0} parameter(0)
%copy = f32[11,17]{1,0} copy(%input)
%pad_value = f32[] parameter(1)
%pad = f32[27,51]{1,0} pad(%copy, %pad_value), padding=2_4_1x1_1_2,
sharding={devices=[2,2]0,1,2,3}
ROOT %result = f32[27,51]{1,0} copy(%pad)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
ShardingPropagation().Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "copy"),
op::Sharding("{devices=[2,2]0,1,2,3}"));
}
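For the shapes in the test above, the expected output dimensions follow the usual pad arithmetic: low edge + high edge + input + (input - 1) * interior. A small worked check that f32[11,17] with padding=2_4_1x1_1_2 indeed yields f32[27,51]:

#include <cstdint>
#include <iostream>

// Padded size of one dimension: edge padding plus `interior` elements inserted
// between every pair of existing elements.
int64_t PaddedDim(int64_t input, int64_t low, int64_t high, int64_t interior) {
  return low + high + input + (input - 1) * interior;
}

int main() {
  std::cout << PaddedDim(11, 2, 4, 1) << "\n";  // 27
  std::cout << PaddedDim(17, 1, 1, 2) << "\n";  // 51
}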
TEST_F(ShardingPropagationTest, PartialReplicatedPadForwardPass) { TEST_F(ShardingPropagationTest, PartialReplicatedPadForwardPass) {
const char* const hlo_string = R"( const char* const hlo_string = R"(
HloModule module HloModule module
@ -856,40 +876,41 @@ TEST_F(ShardingPropagationTest, While) {
HloModule module HloModule module
%cond { %cond {
%vars.cond = (u32[], f32[10]{0}) parameter(0) %vars.cond = (u32[], f32[10,10]) parameter(0)
%count.cond = u32[] get-tuple-element((u32[], f32[10]{0}) %vars.cond), index=0 %count.cond = u32[] get-tuple-element((u32[], f32[10,10]) %vars.cond), index=0
%limit = u32[] constant(10) %limit = u32[] constant(10)
ROOT %lt = pred[] compare(u32[] %count.cond, u32[] %limit), direction=LT ROOT %lt = pred[] compare(u32[] %count.cond, u32[] %limit), direction=LT
} }
%body { %body {
%vars = (u32[], f32[10]{0}) parameter(0) %vars = (u32[], f32[10,10]) parameter(0)
%count = u32[] get-tuple-element(%vars), index=0 %count = u32[] get-tuple-element(%vars), index=0
%acc = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %vars), index=1 %acc = f32[10,10] get-tuple-element((u32[], f32[10,10]) %vars), index=1
%one = u32[] constant(1) %one = u32[] constant(1)
%count.1 = u32[] add(u32[] %count, u32[] %one), sharding={replicated} %count.1 = u32[] add(u32[] %count, u32[] %one), sharding={replicated}
%acc.1 = f32[10]{0} add(f32[10]{0} %acc, f32[10]{0} %acc) %acc.1 = f32[10,10] add(f32[10,10] %acc, f32[10,10] %acc)
ROOT %tuple = (u32[], f32[10]{0}) tuple(u32[] %count.1, f32[10]{0} %acc.1) ROOT %tuple = (u32[], f32[10,10]) tuple(u32[] %count.1, f32[10,10] %acc.1)
} }
ENTRY %entry { ENTRY %entry {
%p0 = f32[10]{0} parameter(0) %p0 = f32[10,10] parameter(0)
%p0.copy = f32[10]{0} copy(f32[10]{0} %p0) %p0.copy = f32[10,10] copy(f32[10,10] %p0)
%p1 = f32[10]{0} parameter(1) %p1 = f32[10,10] parameter(1)
%zero = u32[] constant(0) %zero = u32[] constant(0)
%init = (u32[], f32[10]{0}) tuple(u32[] %zero, f32[10]{0} %p0.copy) %init = (u32[], f32[10,10]) tuple(u32[] %zero, f32[10,10] %p0.copy)
%while = (u32[], f32[10]{0}) while((u32[], f32[10]{0}) %init), %while = (u32[], f32[10,10]) while((u32[], f32[10,10]) %init),
body=%body, condition=%cond body=%body, condition=%cond
%res = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %while), index=1 %res = f32[10,10] get-tuple-element((u32[], f32[10,10]) %while), index=1
%prev = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %init), index=1 %prev = f32[10,10] get-tuple-element((u32[], f32[10,10]) %init), index=1
%res.1 = f32[10]{0} multiply(f32[10]{0} %res, %prev) %res.1 = f32[10,10] multiply(f32[10,10] %res, %prev)
ROOT %res_tuple = (f32[10]{0}) tuple(f32[10]{0} %res.1) ROOT %res_tuple = (f32[10,10]) tuple(f32[10,10] %res.1)
})"; })";
auto while_is_sharded = [this](HloModule* module, auto while_is_sharded = [this](HloModule* module,
const HloSharding& sharding) { const HloSharding& sharding) {
TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardingPropagation().Run(module)); TF_ASSERT_OK_AND_ASSIGN(bool changed,
ShardingPropagation(/*is_spmd=*/true).Run(module));
EXPECT_TRUE(changed); EXPECT_TRUE(changed);
auto while_instr = FindInstruction(module, "while"); auto while_instr = FindInstruction(module, "while");
EXPECT_NE(nullptr, while_instr); EXPECT_NE(nullptr, while_instr);
@ -911,7 +932,7 @@ ENTRY %entry {
auto body_root = FindInstruction(module.get(), "tuple"); auto body_root = FindInstruction(module.get(), "tuple");
EXPECT_NE(nullptr, body_root); EXPECT_NE(nullptr, body_root);
auto sharding = auto sharding =
ParseSharding("{{replicated}, {devices=[2]0,1}}").ConsumeValueOrDie(); ParseSharding("{{replicated}, {devices=[2,1]0,1}}").ConsumeValueOrDie();
body_root->set_sharding(sharding); body_root->set_sharding(sharding);
while_is_sharded(module.get(), sharding); while_is_sharded(module.get(), sharding);
} }
@ -921,11 +942,30 @@ ENTRY %entry {
ParseAndReturnVerifiedModule(hlo_string)); ParseAndReturnVerifiedModule(hlo_string));
auto acc_1 = FindInstruction(module.get(), "acc.1"); auto acc_1 = FindInstruction(module.get(), "acc.1");
EXPECT_NE(nullptr, acc_1); EXPECT_NE(nullptr, acc_1);
acc_1->set_sharding(ParseSharding("{devices=[2]0,1}").ConsumeValueOrDie()); acc_1->set_sharding(
ParseSharding("{devices=[2,1]0,1}").ConsumeValueOrDie());
while_is_sharded( while_is_sharded(module.get(),
module.get(), ParseSharding("{{replicated}, {devices=[2,1]0,1}}")
ParseSharding("{{replicated}, {devices=[2]0,1}}").ConsumeValueOrDie()); .ConsumeValueOrDie());
}
{
// Merge partial sharding from operand and body.
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
auto acc_1 = FindInstruction(module.get(), "acc.1");
EXPECT_NE(nullptr, acc_1);
acc_1->set_sharding(
ParseSharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}")
.ConsumeValueOrDie());
auto p0 = FindInstruction(module.get(), "p0");
p0->set_sharding(
ParseSharding("{devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}")
.ConsumeValueOrDie());
while_is_sharded(module.get(),
ParseSharding("{{replicated}, {devices=[2,2]0,1,2,3}}")
.ConsumeValueOrDie());
} }
} }


@ -182,6 +182,10 @@ class ConvolutionVisitor {
return permute_dims[id]; return permute_dims[id];
} }
int64 ReverseDimLookUp(absl::Span<const int64> permute_dims, int64 id) {
return std::distance(permute_dims.begin(), absl::c_find(permute_dims, id));
}
HloInstruction* DoesConvolutionFeedReduceWindowOrSelectAndScatter( HloInstruction* DoesConvolutionFeedReduceWindowOrSelectAndScatter(
HloInstruction* instr, int64 depth); HloInstruction* instr, int64 depth);
@ -215,9 +219,10 @@ class ConvolutionVisitor {
// Limit on batch size to apply this technique on. // Limit on batch size to apply this technique on.
int64 limit_on_batch_size_; int64 limit_on_batch_size_;
// We choose the new batch size to be a constant so that space-to-batch // We choose the new batch size to be kNumSplits times that of the old batch
// propagation through several convolutional layers is consistent. // so that space-to-batch propagation through several convolutional layers is
static constexpr int64 kNewBatchSize = 8; // consistent.
static constexpr int64 kNumSplits = 8;
// Depth for searching reduce window // Depth for searching reduce window
static constexpr int64 kReduceWindowSearchDepth = 10; static constexpr int64 kReduceWindowSearchDepth = 10;
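For orientation, a toy version of the two lookup helpers above (standard containers instead of absl::Span, not the class members themselves): DimLookUp reads the permutation directly, while the new ReverseDimLookUp answers the inverse question, i.e. at which position of the permutation a given dimension value sits.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Forward lookup: the value stored at position `id` of the permutation.
int64_t DimLookUp(const std::vector<int64_t>& permute_dims, int64_t id) {
  return permute_dims[id];
}

// Reverse lookup: the position at which `id` is stored (the inverse permutation).
int64_t ReverseDimLookUp(const std::vector<int64_t>& permute_dims, int64_t id) {
  return std::distance(permute_dims.begin(),
                       std::find(permute_dims.begin(), permute_dims.end(), id));
}

int main() {
  const std::vector<int64_t> perm = {2, 0, 1};
  std::cout << DimLookUp(perm, 0) << "\n";         // 2
  std::cout << ReverseDimLookUp(perm, 0) << "\n";  // 1, because perm[1] == 0
}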
@ -301,17 +306,12 @@ bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch(
if (old_batch_size > limit_on_batch_size_) { if (old_batch_size > limit_on_batch_size_) {
return false; return false;
} }
// We currently only cater to evenly divisible cases.
if (kNewBatchSize % old_batch_size != 0) {
return false;
}
VLOG(1) << "spatial size " << c.spatial_size; VLOG(1) << "spatial size " << c.spatial_size;
const int64 num_splits = kNewBatchSize / old_batch_size;
// If the ratio is not within the 2X range, we can't Halo Pad from the next // If the ratio is not within the 2X range, we can't Halo Pad from the next
// split. // split.
if (c.halo_size > CeilOfRatio(c.spatial_size, num_splits)) { if (c.halo_size > CeilOfRatio(c.spatial_size, kNumSplits)) {
return false; return false;
} }
VLOG(1) << "Legal space-to-batch convolution " << convolution->ToString(); VLOG(1) << "Legal space-to-batch convolution " << convolution->ToString();
@ -323,6 +323,24 @@ StatusOr<HloInstruction*> ConvolutionVisitor::HaloDuplicateWithSlice(
int64 activations_batch_dim, int64 old_batch_size, int64 low_padding, int64 activations_batch_dim, int64 old_batch_size, int64 low_padding,
int64 high_padding, int64 halo_size, int64 original_split_dim_size, int64 high_padding, int64 halo_size, int64 original_split_dim_size,
HloInstruction* pad_val) { HloInstruction* pad_val) {
const int64 original_batch_size =
activations->shape().dimensions(activations_batch_dim) / kNumSplits;
if (original_batch_size > 1) {
std::vector<int64> new_dimensions(activations->shape().dimensions().begin(),
activations->shape().dimensions().end());
new_dimensions[activations_batch_dim] = kNumSplits;
new_dimensions.insert(new_dimensions.begin() + activations_batch_dim,
original_batch_size);
// Reshape the output of the new conv into the old convolutions shape.
TF_ASSIGN_OR_RETURN(activations,
MakeReshapeHlo(new_dimensions, activations));
spatial_dimension_to_split++;
activations_batch_dim++;
}
const int64 rank = activations->shape().rank(); const int64 rank = activations->shape().rank();
const int64 spatial_split_size = const int64 spatial_split_size =
activations->shape().dimensions(spatial_dimension_to_split); activations->shape().dimensions(spatial_dimension_to_split);
@ -415,6 +433,21 @@ StatusOr<HloInstruction*> ConvolutionVisitor::HaloDuplicateWithSlice(
TF_ASSIGN_OR_RETURN(activations, MakeConcatHlo({activations, halo_region}, TF_ASSIGN_OR_RETURN(activations, MakeConcatHlo({activations, halo_region},
spatial_dimension_to_split)); spatial_dimension_to_split));
} }
if (original_batch_size > 1) {
std::vector<int64> new_dimensions(activations->shape().dimensions().begin(),
activations->shape().dimensions().end());
new_dimensions[activations_batch_dim] = original_batch_size * kNumSplits;
new_dimensions.erase(new_dimensions.begin() + activations_batch_dim - 1);
// Reshape the output of the new conv into the old convolutions shape.
TF_ASSIGN_OR_RETURN(activations,
MakeReshapeHlo(new_dimensions, activations));
spatial_dimension_to_split++;
activations_batch_dim++;
}
VLOG(1) << "HaloDuplicated activations " << activations->ToString(); VLOG(1) << "HaloDuplicated activations " << activations->ToString();
return activations; return activations;
} }
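A small numeric sketch of the reshape bracketing added around the halo logic above (simplified indices, illustrative values only): a combined batch of size original_batch * kNumSplits is unfolded into two dimensions before the halo slicing, and folded back into a single batch dimension afterwards.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t kNumSplits = 8;
  // Suppose the activations currently have shape [B * kNumSplits, S, F].
  std::vector<int64_t> dims = {16, 32, 64};
  const int64_t batch_dim = 0;
  const int64_t original_batch = dims[batch_dim] / kNumSplits;  // 2

  // Unfold: [B * kNumSplits, S, F] -> [B, kNumSplits, S, F]
  std::vector<int64_t> unfolded = dims;
  unfolded[batch_dim] = kNumSplits;
  unfolded.insert(unfolded.begin() + batch_dim, original_batch);

  // ... halo duplication / slicing would operate on the kNumSplits dimension here ...

  // Fold back: [B, kNumSplits, S, F] -> [B * kNumSplits, S, F]
  std::vector<int64_t> folded = unfolded;
  folded[batch_dim + 1] = original_batch * kNumSplits;
  folded.erase(folded.begin() + batch_dim);

  for (int64_t d : unfolded) std::cout << d << ' ';  // 2 8 32 64
  std::cout << '\n';
  for (int64_t d : folded) std::cout << d << ' ';    // 16 32 64
  std::cout << '\n';
}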
@ -424,17 +457,20 @@ ConvolutionVisitor::BringSpaceNextToBatch(
HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers, HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
int64& spatial_dimension_to_split, int64& activations_batch_dim, int64& spatial_dimension_to_split, int64& activations_batch_dim,
bool is_backprop) { bool is_backprop) {
std::vector<int64> transpose_dims; std::vector<int64> transpose_dims(activations->shape().rank());
ConvolutionDimensionNumbers new_dim_numbers = dim_numbers; if (spatial_dimension_to_split == activations_batch_dim + 1) {
if (spatial_dimension_to_split != activations_batch_dim + 1) { absl::c_iota(transpose_dims, 0);
} else {
ConvolutionDimensionNumbers new_dim_numbers = dim_numbers;
int64 pushed_counter = 0; int64 pushed_counter = 0;
int64 new_batch_dim, new_spatial_dim; int64 new_batch_dim, new_spatial_dim;
int64 dim_counter = 0;
for (int i = 0; i < activations->shape().rank(); ++i) { for (int i = 0; i < activations->shape().rank(); ++i) {
if (i == activations_batch_dim) { if (i == activations_batch_dim) {
continue; continue;
} }
if (i == spatial_dimension_to_split) { if (i == spatial_dimension_to_split) {
transpose_dims.push_back(activations_batch_dim); transpose_dims[dim_counter++] = activations_batch_dim;
new_batch_dim = pushed_counter; new_batch_dim = pushed_counter;
pushed_counter++; pushed_counter++;
new_spatial_dim = pushed_counter; new_spatial_dim = pushed_counter;
@ -452,7 +488,7 @@ ConvolutionVisitor::BringSpaceNextToBatch(
} }
} }
} }
transpose_dims.push_back(i); transpose_dims[dim_counter++] = i;
pushed_counter++; pushed_counter++;
} }
@ -460,14 +496,14 @@ ConvolutionVisitor::BringSpaceNextToBatch(
spatial_dimension_to_split = new_spatial_dim; spatial_dimension_to_split = new_spatial_dim;
TF_ASSIGN_OR_RETURN(activations, TF_ASSIGN_OR_RETURN(activations,
MakeTransposeHlo(activations, transpose_dims)); MakeTransposeHlo(activations, transpose_dims));
}
if (is_backprop) { if (is_backprop) {
new_dim_numbers.set_input_feature_dimension(activations_batch_dim); new_dim_numbers.set_input_feature_dimension(activations_batch_dim);
} else { } else {
new_dim_numbers.set_input_batch_dimension(activations_batch_dim); new_dim_numbers.set_input_batch_dimension(activations_batch_dim);
}
dim_numbers = new_dim_numbers;
} }
dim_numbers = new_dim_numbers;
return SpaceNextToBatchDetails{activations, transpose_dims}; return SpaceNextToBatchDetails{activations, transpose_dims};
} }
@ -586,12 +622,23 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
VLOG(1) << "Checking if conv is supported for propagation " VLOG(1) << "Checking if conv is supported for propagation "
<< consumer->ToString(); << consumer->ToString();
if (IsConvSuitableForSpaceToBatch(consumer)) { if (IsConvSuitableForSpaceToBatch(consumer)) {
for (int64 i = 0; i < consumer->operand_count(); ++i) { if (!old_to_new_instrs_.contains(consumer->mutable_operand(0))) {
auto old_producer = consumer->mutable_operand(i); return false;
if (i == 0 && !old_to_new_instrs_.contains(old_producer)) {
return false;
}
} }
auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)];
// Make sure that the space dimension is the same across the producer
// and consumer.
if (consumer->convolution_dimension_numbers().input_spatial_dimensions(
get_chosen_spatial_dim(consumer)) != dim_map_val_op_0.second) {
return false;
}
// Make sure that the batch dimension is the same across the producer
// and consumer.
if (consumer->convolution_dimension_numbers().input_batch_dimension() !=
dim_map_val_op_0.first) {
return false;
}
return true; return true;
} }
@ -611,13 +658,35 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
VLOG(2) << "Checking for backprop filter conv operands " VLOG(2) << "Checking for backprop filter conv operands "
<< consumer->operand_count(); << consumer->operand_count();
if (!old_to_new_instrs_.contains(consumer->mutable_operand(1))) { auto activations = consumer->mutable_operand(0);
auto kernel = consumer->mutable_operand(1);
if (!old_to_new_instrs_.contains(kernel)) {
VLOG(2) << "Backprop filter conv not ready for propagation because of " VLOG(2) << "Backprop filter conv not ready for propagation because of "
"kernel is not space-to-batched"; "kernel is not space-to-batched";
return false; return false;
} }
if (!old_to_new_instrs_.contains(consumer->mutable_operand(0))) { if (!old_to_new_instrs_.contains(activations)) {
const int64 lhs_batch = activations->shape().dimensions(
consumer->convolution_dimension_numbers().input_feature_dimension());
auto dim_map_val_op_1 = instr_to_dim_map_[consumer->mutable_operand(1)];
const int64 old_batch_dim = dim_map_val_op_1.first;
auto second_operand = old_to_new_instrs_[kernel];
auto permute_dims_second_operand =
instr_to_dim_permute_map_[second_operand];
const int64 new_batch_dim =
DimLookUp(permute_dims_second_operand, old_batch_dim);
const int64 rhs_batch = second_operand->shape().dimensions(new_batch_dim);
// Because we want to convert activations into a space-to-batched version
// only for backprop filter convolutions, we want to make sure that the
// batch dimensions (feature dimensions, technically) are same sized.
// Since RHS is already space-to-batched, we need to account for it too.
if (rhs_batch != kNumSplits * lhs_batch) {
return false;
}
// If activations have not been propagated through, we can do // If activations have not been propagated through, we can do
// space-to-batch on them provided kernel has been propagated. // space-to-batch on them provided kernel has been propagated.
VLOG(2) << "Backprop filter conv ready for propagation: kernel ready, " VLOG(2) << "Backprop filter conv ready for propagation: kernel ready, "
@ -625,10 +694,10 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
return true; return true;
} }
auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)]; auto first_operand = old_to_new_instrs_[activations];
auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)]; auto dim_map_val_op_0 = instr_to_dim_map_[activations];
auto second_operand = old_to_new_instrs_[consumer->mutable_operand(1)]; auto second_operand = old_to_new_instrs_[kernel];
auto dim_map_val_op_1 = instr_to_dim_map_[consumer->mutable_operand(1)]; auto dim_map_val_op_1 = instr_to_dim_map_[kernel];
auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand]; auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand];
auto permute_dims_second_operand = auto permute_dims_second_operand =
@ -1119,7 +1188,7 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
Window new_win; Window new_win;
for (int64 i = 0; i < consumer->window().dimensions().size(); ++i) { for (int64 i = 0; i < consumer->window().dimensions().size(); ++i) {
auto dim = DimLookUp(permute_dims, i); auto dim = ReverseDimLookUp(permute_dims, i);
new_win.add_dimensions(); new_win.add_dimensions();
new_win.mutable_dimensions(i)->set_stride( new_win.mutable_dimensions(i)->set_stride(
consumer->window().dimensions(dim).stride()); consumer->window().dimensions(dim).stride());
@ -1339,7 +1408,9 @@ StatusOr<HloInstruction*> ConvolutionVisitor::SelectValidPortion(
const int64 new_space_size = new_shape.dimensions(new_space_dim); const int64 new_space_size = new_shape.dimensions(new_space_dim);
const int64 old_batch_size = old_shape.dimensions(old_batch_dim); const int64 old_batch_size = old_shape.dimensions(old_batch_dim);
const int64 old_space_size = old_shape.dimensions(old_space_dim); const int64 old_space_size = old_shape.dimensions(old_space_dim);
CHECK_EQ(new_batch_size % old_batch_size, 0); CHECK_EQ(new_batch_size % old_batch_size, 0)
<< " New batch size " << new_batch_size << " old batch size "
<< old_batch_size;
const int64 num_splits = new_batch_size / old_batch_size; const int64 num_splits = new_batch_size / old_batch_size;
// Build a constant PRED to decide which elements in the split dimension // Build a constant PRED to decide which elements in the split dimension
// are from halo. // are from halo.
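// Illustrative sizes (hypothetical): with new_batch_size = 32 and
// old_batch_size = 4, num_splits is 8; the PRED mask built below marks the
// positions of each split that fall outside the original space extent as
// halo so they can be masked off.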
@ -1394,8 +1465,10 @@ StatusOr<HloInstruction*> ConvolutionVisitor::BatchToSpace(
CHECK(old_to_new_instrs_.contains(old_instr)); CHECK(old_to_new_instrs_.contains(old_instr));
auto new_instr = old_to_new_instrs_[old_instr]; auto new_instr = old_to_new_instrs_[old_instr];
VLOG(2) << "old_batch_dim " << old_batch_dim << " old_space_dim " VLOG(2) << "old_batch_dim " << old_batch_dim << " old_space_dim "
<< old_space_dim << " new_instr " << new_instr->ToString() << old_space_dim << " old_instr " << old_instr->ToString()
<< " permute dims " << instr_to_dim_permute_map_.count(new_instr); << "\n new_instr " << new_instr->ToString() << " permute dims "
<< instr_to_dim_permute_map_.count(new_instr) << " old_batch_size "
<< old_batch_size;
CHECK(instr_to_dim_permute_map_.contains(new_instr)); CHECK(instr_to_dim_permute_map_.contains(new_instr));
auto permute_dims = instr_to_dim_permute_map_[new_instr]; auto permute_dims = instr_to_dim_permute_map_[new_instr];
const int64 batch_dim = DimLookUp(permute_dims, old_batch_dim); const int64 batch_dim = DimLookUp(permute_dims, old_batch_dim);
@ -1565,6 +1638,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
c.spatial_dimension_to_split, activations_batch_dim)); c.spatial_dimension_to_split, activations_batch_dim));
activations_new = retval.instr; activations_new = retval.instr;
std::vector<int64> trans_dims = retval.transpose_dims; std::vector<int64> trans_dims = retval.transpose_dims;
CHECK(!trans_dims.empty());
auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant( auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(activations_new->shape().element_type()))); LiteralUtil::Zero(activations_new->shape().element_type())));
@ -1578,8 +1652,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
VLOG(1) << "spatial size " << c.spatial_size; VLOG(1) << "spatial size " << c.spatial_size;
const int64 num_splits = kNewBatchSize / old_batch_size; const int64 num_splits = kNumSplits;
const int64 output_offsets = convolution->shape().dimensions( const int64 output_offsets = convolution->shape().dimensions(
permuted_conv_dims_numbers.output_spatial_dimensions( permuted_conv_dims_numbers.output_spatial_dimensions(
get_chosen_spatial_dim(convolution))); get_chosen_spatial_dim(convolution)));
@ -1614,6 +1687,8 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
activations_new->shape().dimensions().end()); activations_new->shape().dimensions().end());
const int64 reshaped_space_size = const int64 reshaped_space_size =
new_space_size * new_batch_size / old_batch_size; new_space_size * new_batch_size / old_batch_size;
VLOG(3) << "Increasing the spatial size while propagating new_batch_size "
<< new_batch_size << " old_batch_size " << old_batch_size;
new_dimensions[c.spatial_dimension_to_split] = reshaped_space_size; new_dimensions[c.spatial_dimension_to_split] = reshaped_space_size;
new_dimensions[activations_batch_dim] = old_batch_size; new_dimensions[activations_batch_dim] = old_batch_size;
@ -1621,10 +1696,12 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_activations, TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_activations,
MakeReshapeHlo(new_dimensions, activations_new)); MakeReshapeHlo(new_dimensions, activations_new));
VLOG(3) << "First reshape done";
PaddingConfig padding_config = PaddingConfig padding_config =
MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size()); MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size());
padding_config.mutable_dimensions(c.spatial_dimension_to_split) padding_config.mutable_dimensions(c.spatial_dimension_to_split)
->set_edge_padding_high(spatial_split_size * new_batch_size - ->set_edge_padding_high(spatial_split_size * new_batch_size /
old_batch_size -
reshaped_space_size); reshaped_space_size);
padding_config.mutable_dimensions(c.spatial_dimension_to_split) padding_config.mutable_dimensions(c.spatial_dimension_to_split)
->set_edge_padding_low(0); ->set_edge_padding_low(0);
@ -1647,6 +1724,8 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
reshaped_activations, reshaped_activations,
MakeReshapeHlo(reshape_back_dims, reshaped_activations)); MakeReshapeHlo(reshape_back_dims, reshaped_activations));
VLOG(3) << "Second reshape done";
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
activations_new, activations_new,
HaloDuplicateWithSlice( HaloDuplicateWithSlice(
@ -1664,6 +1743,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
// additional space available, and adjust the required slice size (and // additional space available, and adjust the required slice size (and
// thereby the halo size). // thereby the halo size).
if (spatial_split_size < new_space_size) { if (spatial_split_size < new_space_size) {
VLOG(3) << "Decreasing the spatial size while propagating";
const int64 additional_space_present = spatial_split_size % c.stride; const int64 additional_space_present = spatial_split_size % c.stride;
spatial_split_size = new_space_size; spatial_split_size = new_space_size;
slice_size = slice_size =
@ -1758,6 +1838,7 @@ ConvolutionVisitor::SplitSpace(HloInstruction* activations,
activations = retval.instr; activations = retval.instr;
std::vector<int64> transpose_dims = retval.transpose_dims; std::vector<int64> transpose_dims = retval.transpose_dims;
CHECK(!transpose_dims.empty());
// Because we are splitting the spatial dimension, if convolution needed // Because we are splitting the spatial dimension, if convolution needed
// padding in the spatial dimension, we materialize it. // padding in the spatial dimension, we materialize it.
if (high_padding || low_padding) { if (high_padding || low_padding) {
@ -1774,7 +1855,9 @@ ConvolutionVisitor::SplitSpace(HloInstruction* activations,
MakePadHlo(activations, padding, padding_config)); MakePadHlo(activations, padding, padding_config));
} }
VLOG(1) << "Initial padded activations shape " VLOG(1) << "Initial padded activations shape "
<< activations->shape().ToString(); << activations->shape().ToString() << " old_batch_size "
<< old_batch_size << " activations_batch_dim "
<< activations_batch_dim;
// Now we reorganize the activations. E.g. if the shape [B, SPACE] was [1, 16] // Now we reorganize the activations. E.g. if the shape [B, SPACE] was [1, 16]
// and 4 splits were needed, we first create [4, 4]. Next, to deal with halo // and 4 splits were needed, we first create [4, 4]. Next, to deal with halo
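// Continuing the example above with a hypothetical halo of 2: each of the 4
// rows of the [4, 4] layout is extended with the first 2 elements of the
// following row, giving [4, 6], so every split carries the overlap a kernel
// spanning a split boundary would need.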
@ -1829,7 +1912,10 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
CHECK(old_to_new_instrs_.contains(kernel_old)); CHECK(old_to_new_instrs_.contains(kernel_old));
auto kernel_new = old_to_new_instrs_[kernel_old]; auto kernel_new = old_to_new_instrs_[kernel_old];
auto permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
HloInstruction* activations_new = nullptr; HloInstruction* activations_new = nullptr;
bool activations_locally_space_to_batched = false;
// If activations were not space-to-batched, we space-to-batch them below. // If activations were not space-to-batched, we space-to-batch them below.
if (!old_to_new_instrs_.contains(activations_old)) { if (!old_to_new_instrs_.contains(activations_old)) {
VLOG(1) << "Space-to-batching activations to enable space-to-depth"; VLOG(1) << "Space-to-batching activations to enable space-to-depth";
@ -1838,28 +1924,34 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
instr_to_dim_map_[activations_old] = instr_to_dim_map_[activations_old] =
std::make_pair(prev_feature_dim, prev_batch_dim); std::make_pair(prev_feature_dim, prev_batch_dim);
int64 activations_batch_dim = original_conv_dims.input_feature_dimension(); const int64 new_kernel_space_dim =
const int64 old_batch_size = DimLookUp(permute_dims_kernel, kernel_space_dim);
activations_old->shape().dimensions(activations_batch_dim);
const int64 num_splits = kNewBatchSize / old_batch_size;
const int64 new_kernel_split_dim_size = const int64 new_kernel_split_dim_size =
kernel_new->shape().dimensions(kernel_space_dim); kernel_new->shape().dimensions(new_kernel_space_dim);
const int64 needed_spatial_size = rhs_dilation * new_kernel_split_dim_size; const int64 needed_spatial_size = rhs_dilation * new_kernel_split_dim_size;
const int64 pad_size = const int64 pad_size =
needed_spatial_size * num_splits - old_split_dim_size; needed_spatial_size * kNumSplits - old_split_dim_size;
ConvolutionDimensionNumbers tmp_dim_numbers; ConvolutionDimensionNumbers tmp_dim_numbers;
tmp_dim_numbers = original_conv_dims; tmp_dim_numbers = original_conv_dims;
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
auto retval, auto retval,
SplitSpace(activations_old, tmp_dim_numbers, old_space_dim, SplitSpace(activations_old, tmp_dim_numbers, old_space_dim,
activations_batch_dim, old_batch_dim,
/*high_padding=*/pad_size, /*low_padding=*/0, /*high_padding=*/pad_size, /*low_padding=*/0,
needed_spatial_size, num_splits, /*is_backprop=*/true)); needed_spatial_size, kNumSplits, /*is_backprop=*/true));
old_to_new_instrs_[activations_old] = retval.first; old_to_new_instrs_[activations_old] = retval.first;
instr_to_dim_permute_map_[retval.first] = retval.second;
VLOG(3) << "Edited conv dims " << original_conv_dims.DebugString(); std::vector<int64> reversed_transpose_dims(retval.second.size());
for (int64 i = 0; i < retval.second.size(); ++i) {
reversed_transpose_dims[i] = ReverseDimLookUp(retval.second, i);
}
instr_to_dim_permute_map_[retval.first] = reversed_transpose_dims;
VLOG(3) << "New Activations " << retval.first->ToString();
activations_locally_space_to_batched = true;
} }
CHECK(old_to_new_instrs_.contains(activations_old)); CHECK(old_to_new_instrs_.contains(activations_old));
@ -1884,7 +1976,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
i, DimLookUp(permute_dims, i, DimLookUp(permute_dims,
original_conv_dims.input_spatial_dimensions(i))); original_conv_dims.input_spatial_dimensions(i)));
permuted_conv_dims_numbers.set_kernel_spatial_dimensions( permuted_conv_dims_numbers.set_kernel_spatial_dimensions(
i, DimLookUp(permute_dims, i, DimLookUp(permute_dims_kernel,
original_conv_dims.kernel_spatial_dimensions(i))); original_conv_dims.kernel_spatial_dimensions(i)));
} }
@ -1905,10 +1997,11 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
previous_spatial_dim_count, previous_chosen_spatial_dim_in_output); previous_spatial_dim_count, previous_chosen_spatial_dim_in_output);
const int64 kernel_input_feature_dim = DimLookUp( const int64 kernel_input_feature_dim = DimLookUp(
permute_dims, original_conv_dims.kernel_input_feature_dimension()); permute_dims_kernel, original_conv_dims.kernel_input_feature_dimension());
const int64 kernel_output_feature_dim = DimLookUp( const int64 kernel_output_feature_dim =
permute_dims, original_conv_dims.kernel_output_feature_dimension()); DimLookUp(permute_dims_kernel,
original_conv_dims.kernel_output_feature_dimension());
permuted_conv_dims_numbers.set_kernel_input_feature_dimension( permuted_conv_dims_numbers.set_kernel_input_feature_dimension(
kernel_input_feature_dim); kernel_input_feature_dim);
@ -1931,7 +2024,8 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
VLOG(1) << "Propagating on conv activations_batch_dim " VLOG(1) << "Propagating on conv activations_batch_dim "
<< activations_batch_dim << " spatial_dimension_to_split " << activations_batch_dim << " spatial_dimension_to_split "
<< spatial_dimension_to_split << " old_batch_size " << old_batch_size; << spatial_dimension_to_split << " old_batch_size " << old_batch_size
<< " new_split_dim_size " << new_split_dim_size;
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
auto retval, auto retval,
@ -1939,6 +2033,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
spatial_dimension_to_split, activations_batch_dim, spatial_dimension_to_split, activations_batch_dim,
/*is_backprop=*/true)); /*is_backprop=*/true));
std::vector<int64> transpose_dims = retval.transpose_dims; std::vector<int64> transpose_dims = retval.transpose_dims;
CHECK(!transpose_dims.empty());
activations_new = retval.instr; activations_new = retval.instr;
VLOG(1) << "Activations_new post BringSpaceNextToBatch " VLOG(1) << "Activations_new post BringSpaceNextToBatch "
@ -1949,13 +2044,15 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant( auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(activations_new->shape().element_type()))); LiteralUtil::Zero(activations_new->shape().element_type())));
// Select activations correctly by masking additional space. if (!activations_locally_space_to_batched) {
TF_ASSIGN_OR_RETURN( // Select activations correctly by masking additional space.
activations_new, TF_ASSIGN_OR_RETURN(
SelectValidPortion(activations_new, activations_old, select_val, activations_new,
activations_batch_dim, spatial_dimension_to_split, SelectValidPortion(activations_new, activations_old, select_val,
old_batch_dim, old_space_dim)); activations_batch_dim, spatial_dimension_to_split,
old_batch_dim, old_space_dim));
}
VLOG(3) << "Selecting the valid kernel area";
// Select kernel correctly by masking additional space. // Select kernel correctly by masking additional space.
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
kernel_new, kernel_new,
@ -2238,7 +2335,6 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
VLOG(1) << "spatial size " << c.spatial_size; VLOG(1) << "spatial size " << c.spatial_size;
const int64 num_splits = kNewBatchSize / old_batch_size;
auto original_conv = convolution; auto original_conv = convolution;
const int64 output_spatial_dim = dim_numbers.output_spatial_dimensions( const int64 output_spatial_dim = dim_numbers.output_spatial_dimensions(
@ -2246,13 +2342,13 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
const int64 output_offsets = const int64 output_offsets =
convolution->shape().dimensions(output_spatial_dim); convolution->shape().dimensions(output_spatial_dim);
const int64 output_offsets_per_split = const int64 output_offsets_per_split =
CeilOfRatio(output_offsets, num_splits); CeilOfRatio(output_offsets, kNumSplits);
int64 spatial_split_size = int64 spatial_split_size =
CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride; CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;
// Keep increasing the split size so that overall size isn't smaller than the // Keep increasing the split size so that overall size isn't smaller than the
// original spatial dimension. // original spatial dimension.
while (spatial_split_size * num_splits - c.spatial_size < 0) { while (spatial_split_size * kNumSplits - c.spatial_size < 0) {
spatial_split_size += c.stride; spatial_split_size += c.stride;
} }
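// Illustrative numbers (hypothetical): with kNumSplits = 8, stride = 2,
// spatial_size = 35 and an initial spatial_split_size of 4, the loop above
// bumps the split size to 6 so that 6 * 8 = 48 covers the original 35.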
@ -2276,12 +2372,12 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
const int64 slice_size = spatial_split_size + c.halo_size; const int64 slice_size = spatial_split_size + c.halo_size;
// Pad spatial dim. // Pad spatial dim.
const int64 pad_size = spatial_split_size * num_splits - c.spatial_size; const int64 pad_size = spatial_split_size * kNumSplits - c.spatial_size;
VLOG(1) << "spatial_split_size " << spatial_split_size << " stride " VLOG(1) << "spatial_split_size " << spatial_split_size << " stride "
<< c.stride << " slice_size " << slice_size; << c.stride << " slice_size " << slice_size;
VLOG(1) << "spatial_dimension_to_split " << c.spatial_dimension_to_split VLOG(1) << "spatial_dimension_to_split " << c.spatial_dimension_to_split
<< " num_splits " << num_splits << " kernel_spatial_dim_size " << " num_splits " << kNumSplits << " kernel_spatial_dim_size "
<< c.kernel_spatial_dim_size; << c.kernel_spatial_dim_size;
int64 spatial_dimension_to_split = c.spatial_dimension_to_split; int64 spatial_dimension_to_split = c.spatial_dimension_to_split;
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
@ -2292,7 +2388,7 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
/*low_padding=*/c.base_dilation_factor == 1 /*low_padding=*/c.base_dilation_factor == 1
? c.inherent_low_padding ? c.inherent_low_padding
: 0, : 0,
spatial_split_size, num_splits)); spatial_split_size, kNumSplits));
HloInstruction* batch_increased_reshape = retval.first; HloInstruction* batch_increased_reshape = retval.first;
convolution->SetupDerivedInstruction(batch_increased_reshape); convolution->SetupDerivedInstruction(batch_increased_reshape);

View File

@ -317,5 +317,27 @@ TEST_F(DynamismInferenceTest, GatherWithSharedConstantParent) {
} }
} }
TEST_F(DynamismInferenceTest, InferThroughPad) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
// Test the analysis on a pad.
auto operand1 = ConstantR1<int32>(&b, {1, 2});
auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
PaddingConfig padding_config;
padding_config.add_dimensions()->set_edge_padding_high(1);
// After the pad, the value is [constant, constant, parameter].
auto pad = Pad(operand1, parameter, padding_config);
ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
// The operand is constant, so the first two elements of the result are also constant.
EXPECT_FALSE(
ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({0}));
EXPECT_FALSE(
ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({1}));
EXPECT_TRUE(
ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({2}));
}
}
} // namespace } // namespace
} // namespace xla } // namespace xla

View File

@ -57,7 +57,8 @@ class GpuDummyCompiler : public GpuCompiler {
StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary( StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
const HloModule* hlo_module, llvm::Module* llvm_module, const HloModule* hlo_module, llvm::Module* llvm_module,
GpuVersion gpu_version, se::StreamExecutor* stream_exec) { GpuVersion gpu_version, se::StreamExecutor* stream_exec,
bool relocatable) {
if (user_post_optimization_hook_) { if (user_post_optimization_hook_) {
user_post_optimization_hook_(*llvm_module); user_post_optimization_hook_(*llvm_module);
} }

View File

@ -20,9 +20,8 @@ op {
"A `Tensor` of type T. An alias of `x`. The content " "A `Tensor` of type T. An alias of `x`. The content "
"of `y` is undefined if there are duplicates in `i`." "of `y` is undefined if there are duplicates in `i`."
} }
summary: <<END summary: "Adds v into specified rows of x."
Adds v into specified rows of x. description: <<END
Computes y = x; y[i, :] += v; return y. Computes y = x; y[i, :] += v; return y.
END END
} }
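For readers of the op documentation above, the row-update formula (y = x; y[i, :] += v) can be made concrete with a minimal standalone C++ sketch; the helper name and shapes are illustrative only, not the TensorFlow kernel:

#include <cstdio>
#include <vector>

// InplaceAdd semantics sketch: y = x; y[i, :] += v; return y.
std::vector<std::vector<float>> InplaceAddSketch(
    std::vector<std::vector<float>> x,           // copied, so y starts as x
    const std::vector<int>& i,                   // row indices to update
    const std::vector<std::vector<float>>& v) {  // one update row per index
  for (size_t r = 0; r < i.size(); ++r)
    for (size_t c = 0; c < v[r].size(); ++c)
      x[i[r]][c] += v[r][c];                     // y[i, :] += v
  return x;                                      // return y
}

int main() {
  auto y = InplaceAddSketch({{1, 1}, {2, 2}}, {1}, {{10, 20}});
  std::printf("%g %g\n", y[1][0], y[1][1]);  // prints 12 22
}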

View File

@ -1,8 +1,8 @@
op { op {
graph_op_name: "TopKUnique" graph_op_name: "TopKUnique"
summary: "Returns the TopK unique values in the array in sorted order. The" summary: "Returns the TopK unique values in the array in sorted order."
description: <<END description: <<END
running time is proportional to the product of K and the input The running time is proportional to the product of K and the input
size. Sorting the whole array is more efficient for sufficiently large size. Sorting the whole array is more efficient for sufficiently large
values of K. The median-of-medians algorithm is probably faster, but values of K. The median-of-medians algorithm is probably faster, but
difficult to implement efficiently in XLA. If there are fewer than K difficult to implement efficiently in XLA. If there are fewer than K

View File

@ -1,10 +1,11 @@
op { op {
graph_op_name: "TopKWithUnique" graph_op_name: "TopKWithUnique"
summary: "Returns the TopK values in the array in sorted order. This is a combination" summary: "Returns the TopK values in the array in sorted order."
description: <<END description: <<END
of MakeUnique and TopKUnique. The returned top-K will have its lower bits This is a combination of MakeUnique and TopKUnique. The returned top-K will
replaced by iota, thus it will be close to the original value but not exactly have its lower bits replaced by iota, thus it will be close to the original
the same. The running time is proportional to the product of K and the input value but not exactly the same. The running time is proportional to the product
size. NaNs are never returned. Subnormal numbers are flushed to zero. of K and the input size. NaNs are never returned. Subnormal numbers are flushed
to zero.
END END
} }
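As a rough illustration of the "lower bits replaced by iota" remark in the description above, here is a hedged C++ sketch of one way to splice an index into the low bits of a float so that ties become unique; this is an assumed scheme for illustration, not the actual MakeUnique/TopKWithUnique lowering:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Replace the low `index_bits` bits of a float with an iota index, so equal
// inputs become distinct but numerically close values.
float ReplaceLowBitsWithIndex(float v, uint32_t index, int index_bits) {
  uint32_t bits;
  std::memcpy(&bits, &v, sizeof(bits));            // reinterpret float as bits
  const uint32_t low_mask = (1u << index_bits) - 1u;
  bits = (bits & ~low_mask) | (index & low_mask);  // splice in the iota value
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

int main() {
  const float x = 3.14159f;
  // Two equal inputs map to distinct, nearly identical values.
  std::printf("%.9g %.9g\n", ReplaceLowBitsWithIndex(x, 0, 16),
              ReplaceLowBitsWithIndex(x, 1, 16));
}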

View File

@ -705,10 +705,10 @@ Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers(
return Status::OK(); return Status::OK();
} }
Status EagerContext::AddFunctionDefWithDebugInfo( Status EagerContext::AddFunctionDefWithStackTraces(
const FunctionDef& fdef, const Graph* graph_with_debug_info) { const FunctionDef& fdef, const StackTracesMap& stack_traces) {
return AddFunctionDef(fdef, FunctionDefLibrary(), return AddFunctionDef(fdef, FunctionDefLibrary(),
/* add_to_local_only=*/false, graph_with_debug_info); /* add_to_local_only=*/false, stack_traces);
} }
Status EagerContext::AddFunctionDef(const FunctionDef& fdef) { Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
@ -719,7 +719,7 @@ Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
Status EagerContext::AddFunctionDef(const FunctionDef& fdef, Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
const FunctionDefLibrary& library, const FunctionDefLibrary& library,
const bool add_to_local_only, const bool add_to_local_only,
const Graph* graph_with_debug_info) { const StackTracesMap& stack_traces) {
bool is_first_ref = false; bool is_first_ref = false;
{ {
mutex_lock l(cache_mu_); mutex_lock l(cache_mu_);
@ -753,8 +753,7 @@ Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
is_first_ref = registered_function->RefCountIsOne(); is_first_ref = registered_function->RefCountIsOne();
} }
if (is_first_ref) { if (is_first_ref) {
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef, stack_traces));
func_lib_def_.AddFunctionDef(fdef, graph_with_debug_info));
TF_RETURN_IF_ERROR(func_lib_def_.AddLibrary(library)); TF_RETURN_IF_ERROR(func_lib_def_.AddLibrary(library));
if (!add_to_local_only) { if (!add_to_local_only) {
return MaybeRegisterFunctionRemotely(fdef); return MaybeRegisterFunctionRemotely(fdef);
