Merge remote-tracking branch 'upstream/master' into xtensa-fusion-f1
commit 6b24408403
@@ -61,6 +61,7 @@
 *   Added support for saved model's session initializer through
     `TFLiteConverter.from_saved_model`.
 *   Added dynamic range quantization support for the BatchMatMul op.
+*   Added DEPTH_TO_SPACE support in Post training quantization.
 *   Add `RFFT2D` as builtin op. (`RFFT2D` also supports `RFFTD`.) Currently
     only supports float32 input.
 *   TFLite Supports SignatureDef:
@@ -145,6 +145,8 @@ if _running_from_pip_package():
     _plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
     if _os.path.exists(_plugin_dir):
       _ll.load_library(_plugin_dir)
+      # Load Pluggable Device Library
+      _ll.load_pluggable_device_library(_plugin_dir)
 
 # Add module aliases
 if hasattr(_current_module, 'keras'):
@@ -155,6 +155,8 @@ if _running_from_pip_package():
     _plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
     if _os.path.exists(_plugin_dir):
       _ll.load_library(_plugin_dir)
+      # Load Pluggable Device Library
+      _ll.load_pluggable_device_library(_plugin_dir)
 
 # Delete modules that should be hidden from dir().
 # Don't fail if these modules are not available.
@@ -684,7 +684,10 @@ tf_cc_test(
     name = "c_api_experimental_test",
     size = "medium",
     srcs = ["c_api_experimental_test.cc"],
-    data = ["testdata/tf_record"],
+    data = [
+        "testdata/tf_record",
+        "//tensorflow/c/experimental/stream_executor/test:test_pluggable_device.so",
+    ],
     linkopts = select({
         "//tensorflow:macos": ["-headerpad_max_install_names"],
         "//conditions:default": [],
@@ -704,6 +707,7 @@ tf_cc_test(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:resource_loader",
         "@com_google_absl//absl/types:optional",
     ],
 )
@@ -37,7 +37,9 @@ limitations under the License.
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/platform/blocking_counter.h"
 #include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/net.h"
 #include "tensorflow/core/platform/platform.h"
 #include "tensorflow/core/platform/strcat.h"
@@ -630,6 +632,9 @@ void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
 
 namespace tensorflow {
 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
+
+// Helpers for loading a TensorFlow PluggableDevice plugin (a .so file).
+Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
 }  // namespace tensorflow
 
 void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
@@ -743,3 +748,45 @@ void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
     TF_ImportGraphDefOptions* opts, unsigned char enable) {
   opts->opts.validate_colocation_constraints = enable;
 }
+
+// Load a Pluggable Device library.
+// On success, returns the handle to the library in result and returns OK from
+// the function. Otherwise returns nullptr in result and an error Status from
+// the function.
+//
+// If `library_filename` has already been loaded, we return a cached handle.
+// Device and Kernels/Ops are registered as globals when a library is loaded
+// for the first time.
+TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
+                                          TF_Status* status) {
+#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
+  status->status = tensorflow::errors::Unimplemented(
+      "PluggableDevice plugin functionality is not supported on mobile");
+  return nullptr;
+#else
+  TF_Library* lib_handle = new TF_Library;
+  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
+  static std::unordered_map<std::string, void*>* loaded_libs =
+      new std::unordered_map<std::string, void*>();
+  tensorflow::Env* env = tensorflow::Env::Default();
+  {
+    tensorflow::mutex_lock lock(mu);
+    auto it = loaded_libs->find(library_filename);
+    if (it != loaded_libs->end()) {
+      lib_handle->lib_handle = it->second;
+    } else {
+      status->status =
+          env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
+      if (!status->status.ok()) {
+        delete lib_handle;
+        return nullptr;
+      }
+    }
+    return lib_handle;
+  }
+#endif
+}
+
+void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
+  delete lib_handle;
+}
@@ -304,6 +304,27 @@ TF_CAPI_EXPORT extern void
 TF_ImportGraphDefOptionsSetValidateColocationConstraints(
     TF_ImportGraphDefOptions* opts, unsigned char enable);
 
+// Load the library specified by library_filename and register the pluggable
+// device and related kernels present in that library. This function is not
+// supported on mobile and embedded platforms and will fail if
+// called.
+//
+// Pass "library_filename" to a platform-specific mechanism for dynamically
+// loading a library. The rules for determining the exact location of the
+// library are platform-specific and are not documented here.
+//
+// On success, returns the newly created library handle and places OK in
+// status. The caller owns the library handle.
+//
+// On failure, returns nullptr and places an error status in status.
+TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary(
+    const char* library_filename, TF_Status* status);
+
+// Frees the memory associated with the library handle.
+// Does NOT unload the library.
+TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle(
+    TF_Library* lib_handle);
+
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
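The two declarations above are all a host program needs to load a pluggable device through the C API. A minimal usage sketch follows (not part of the diff); it assumes only the standard TF C status helpers, and "plugin.so" is a placeholder path rather than a real library shipped by this change:

#include <cstdio>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"

int main() {
  TF_Status* status = TF_NewStatus();
  // "plugin.so" is a placeholder; point it at an actual pluggable-device
  // library, e.g. the test_pluggable_device.so target added in this change.
  TF_Library* lib = TF_LoadPluggableDeviceLibrary("plugin.so", status);
  if (TF_GetCode(status) != TF_OK) {
    std::fprintf(stderr, "load failed: %s\n", TF_Message(status));
  } else {
    // Deleting the handle frees only the wrapper; the library stays loaded,
    // so the devices and kernels it registered remain available.
    TF_DeletePluggableDeviceLibraryHandle(lib);
  }
  TF_DeleteStatus(status);
  return 0;
}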
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/resource_loader.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
 
@@ -234,5 +235,22 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
   TF_DeleteTensor(tensor_1X6);
 }
 
+TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) {
+#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
+  // Load the library.
+  TF_Status* status = TF_NewStatus();
+  string lib_path =
+      tensorflow::GetDataDependencyFilepath(tensorflow::io::JoinPath(
+          "tensorflow", "c", "experimental", "stream_executor", "test",
+          "test_pluggable_device.so"));
+  TF_Library* lib = TF_LoadPluggableDeviceLibrary(lib_path.c_str(), status);
+  TF_Code code = TF_GetCode(status);
+  string status_msg(TF_Message(status));
+  TF_DeleteStatus(status);
+  ASSERT_EQ(TF_OK, code) << status_msg;
+  TF_DeletePluggableDeviceLibraryHandle(lib);
+#endif  // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
+}
+
 }  // namespace
 }  // namespace tensorflow
@@ -213,7 +213,11 @@ TF_Function* TF_GraphToFunctionWithControlOutputs(
     TF_DeleteFunction(tf_function);
     return nullptr;
   }
-  tf_function->graph_with_debug_info = &fn_body->graph;
+
+  for (const Node* n : fn_body->graph.nodes()) {
+    tf_function->stack_traces[n->name()] = n->GetStackTrace();
+  }
+
   return tf_function;
 }
 
@@ -157,9 +157,7 @@ struct TF_DeviceList {
 
 struct TF_Function {
   tensorflow::FunctionDef fdef;
+  tensorflow::StackTracesMap stack_traces;
-  // Graph with nodes with debug stack traces.
-  const tensorflow::Graph* graph_with_debug_info = nullptr;
 };
 
 struct TF_ApiDefMap {
@@ -749,8 +749,8 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
 
 void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
                             TF_Status* status) {
-  status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithDebugInfo(
-      function->fdef, function->graph_with_debug_info);
+  status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithStackTraces(
+      function->fdef, function->stack_traces);
 }
 
 void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
@@ -111,11 +111,11 @@ class ImmediateExecutionContext : public AbstractContext {
   // already exists.
   virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
 
-  // Same as `AddFunctionDef`, and additionally saves a pointer to the Graph
-  // which has nodes containing stack traces for the nodes in `fdef`. Assumes
-  // `graph` is alive while the function is alive.
-  virtual Status AddFunctionDefWithDebugInfo(const FunctionDef& fdef,
-                                             const Graph* graph) = 0;
+  // Same as `AddFunctionDef`, but additionally saves the `stack_traces` under
+  // the key of the function definition name (to be retrieved during function
+  // instantiation).
+  virtual Status AddFunctionDefWithStackTraces(
+      const FunctionDef& fdef, const StackTracesMap& stack_traces) = 0;
 
   // Find and return an added function by its name.
   virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;
tensorflow/c/experimental/stream_executor/test/BUILD (new file, 17 lines)
@@ -0,0 +1,17 @@
+# Description:
+# test for stream_executor
+load(
+    "//tensorflow:tensorflow.bzl",
+    "tf_cc_shared_object",
+)
+
+package(
+    licenses = ["notice"],  # Apache 2.0
+)
+
+tf_cc_shared_object(
+    name = "test_pluggable_device.so",
+    srcs = ["test_pluggable_device.cc"],
+    visibility = ["//tensorflow/c:__subpackages__"],
+    deps = ["//tensorflow/c/experimental/stream_executor:stream_executor_hdrs"],
+)
@@ -1,4 +1,4 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,16 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_
-#define TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_
-
-#include "pybind11/pybind11.h"
-#include "tensorflow/compiler/xla/statusor.h"
-
-namespace xla {
-
-xla::StatusOr<pybind11::object> Bfloat16Dtype();
-
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_
+#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
+
+void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
+                   TF_Status* const status) {
+  params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
+  params->platform->name = "GPU";
+  params->platform->type = "XGPU";
+}
@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/stream.h"
 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
 
+using tensorflow::errors::InvalidArgument;
 // This file forms the basis of a stable ABI for third-party kernel
 // implementations. It is crucial that changes to this file are made cautiously
 // and with a focus on maintaining both source and binary compatibility.
@@ -87,9 +88,25 @@ void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name,
   TF_SetStatus(status, TF_OK, "");
 }
 #undef CASE
 
 }  // namespace
 }  // namespace tensorflow
 
+namespace {
+const tensorflow::AttrValue* GetAttrValue(TF_OpKernelConstruction* ctx,
+                                          const char* attr_name,
+                                          TF_Status* status) {
+  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
+  const tensorflow::AttrValue* attr =
+      ::tensorflow::AttrSlice(cc_ctx->def()).Find(attr_name);
+  if (attr == nullptr) {
+    status->status = InvalidArgument("Operation '", cc_ctx->def().name(),
+                                     "' has no attr named '", attr_name, "'.");
+  }
+  return attr;
+}
+}  // namespace
+
 void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder,
                                      const char* attr_name,
                                      const TF_DataType type,
@@ -257,7 +274,81 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
   cc_ctx->CtxFailure(s);
 }
 
-#define DEFINE_TF_GETATTR(func, c_type, cc_type) \
+void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx,
+                                         const char* attr_name,
+                                         int32_t* list_size,
+                                         int32_t* total_size,
+                                         TF_Status* status) {
+  const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);
+  if (!status->status.ok()) {
+    *list_size = -1;
+    *total_size = -1;
+    return;
+  }
+  switch (attr->value_case()) {
+#define SINGLE_CASE(kK, attr_type, size_expr) \
+  case tensorflow::AttrValue::kK:             \
+    *list_size = -1;                          \
+    *total_size = size_expr;                  \
+    break;
+
+    SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
+    SINGLE_CASE(kI, TF_ATTR_INT, -1);
+    SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
+    SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
+    SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
+    SINGLE_CASE(kShape, TF_ATTR_SHAPE,
+                attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
+    SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
+#undef SINGLE_CASE
+
+    case tensorflow::AttrValue::kList:
+      *list_size = 0;
+      *total_size = -1;
+#define LIST_CASE(field, attr_type, ...)      \
+  if (attr->list().field##_size() > 0) {      \
+    *list_size = attr->list().field##_size(); \
+    __VA_ARGS__;                              \
+    break;                                    \
+  }
+
+      LIST_CASE(
+          s, TF_ATTR_STRING, *total_size = 0;
+          for (int i = 0; i < attr->list().s_size();
+               ++i) { *total_size += attr->list().s(i).size(); });
+      LIST_CASE(i, TF_ATTR_INT);
+      LIST_CASE(f, TF_ATTR_FLOAT);
+      LIST_CASE(b, TF_ATTR_BOOL);
+      LIST_CASE(type, TF_ATTR_TYPE);
+      LIST_CASE(
+          shape, TF_ATTR_SHAPE, *total_size = 0;
+          for (int i = 0; i < attr->list().shape_size(); ++i) {
+            const auto& s = attr->list().shape(i);
+            *total_size += s.unknown_rank() ? 0 : s.dim_size();
+          });
+      LIST_CASE(tensor, TF_ATTR_TENSOR);
+      LIST_CASE(tensor, TF_ATTR_FUNC);
+#undef LIST_CASE
+      break;
+
+    case tensorflow::AttrValue::kPlaceholder:
+      *list_size = -1;
+      *total_size = -1;
+      break;
+
+    case tensorflow::AttrValue::kFunc:
+      *list_size = -1;
+      *total_size = -1;
+      break;
+
+    case tensorflow::AttrValue::VALUE_NOT_SET:
+      status->status =
+          InvalidArgument("Attribute '", attr_name, "' has no value set");
+      break;
+  }
+}
+
+#define DEFINE_TF_GETATTR(func, c_type, cc_type, attr_type, list_field)      \
 void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx,     \
                                            const char* attr_name,            \
                                            c_type* val, TF_Status* status) { \
@@ -269,10 +360,84 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
   if (s.ok()) {                                                              \
     *val = static_cast<c_type>(v);                                           \
   }                                                                          \
+  }                                                                          \
+  void TF_OpKernelConstruction_GetAttr##func##List(                          \
+      TF_OpKernelConstruction* ctx, const char* attr_name, c_type* vals,     \
+      int max_vals, TF_Status* status) {                                     \
+    TF_SetStatus(status, TF_OK, "");                                         \
+    const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status); \
+    if (!status->status.ok()) return;                                        \
+    if (attr->value_case() != tensorflow::AttrValue::kList) {                \
+      status->status =                                                       \
+          InvalidArgument("Value for '", attr_name, "' is not a list.");     \
+      return;                                                                \
+    }                                                                        \
+    status->status =                                                         \
+        tensorflow::AttrValueHasType(*attr, "list(" attr_type ")");          \
+    if (!status->status.ok()) return;                                        \
+    const auto len = std::min(max_vals, attr->list().list_field##_size());   \
+    for (int i = 0; i < len; ++i) {                                          \
+      vals[i] = static_cast<c_type>(attr->list().list_field(i));             \
+    }                                                                        \
+  }
 }
 
-DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
-DEFINE_TF_GETATTR(Int32, tensorflow::int32, int32_t)
+DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType, "type", type)
+DEFINE_TF_GETATTR(Int32, int32_t, tensorflow::int32, "int", i)
+DEFINE_TF_GETATTR(Int64, int64_t, tensorflow::int64, "int", i)
+DEFINE_TF_GETATTR(Float, float, float, "float", f)
+DEFINE_TF_GETATTR(Bool, TF_Bool, bool, "bool", b)
+
+void TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction* ctx,
+                                           const char* attr_name, char* value,
+                                           size_t max_length,
+                                           TF_Status* status) {
+  std::string v;
+  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
+  ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
+  ::tensorflow::Set_TF_Status_from_Status(status, s);
+
+  if (!status->status.ok()) return;
+
+  if (max_length <= 0) {
+    return;
+  }
+  std::memcpy(value, v.data(), std::min<size_t>(v.length(), max_length));
+}
+
+void TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction* ctx,
+                                               const char* attr_name,
+                                               char** values, size_t* lengths,
+                                               int max_values, void* storage,
+                                               size_t storage_size,
+                                               TF_Status* status) {
+  std::vector<std::string> v;
+  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
+  ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
+  ::tensorflow::Set_TF_Status_from_Status(status, s);
+
+  if (!status->status.ok()) return;
+
+  const auto len = std::min(max_values, static_cast<int>(v.size()));
+  char* p = static_cast<char*>(storage);
+  for (int i = 0; i < len; ++i) {
+    const std::string& s = v[i];
+    values[i] = p;
+    lengths[i] = s.size();
+    if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
+      status->status = InvalidArgument(
+          "Not enough storage to hold the requested list of strings");
+      return;
+    }
+    memcpy(values[i], s.data(), s.size());
+    p += s.size();
+  }
+}
+
+bool TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction* ctx,
+                                     const char* attr_name, TF_Status* status) {
+  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
+  return cc_ctx->HasAttr(attr_name);
+}
+
 TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) {
   auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelConstruction*>(ctx);
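To make the widened macro easier to follow: each DEFINE_TF_GETATTR instantiation now emits both a scalar getter and a matching list getter. As a sketch (signatures only, mirroring the declarations added to the header later in this diff), the Float instantiation corresponds to:

// Produced by DEFINE_TF_GETATTR(Float, float, float, "float", f):
void TF_OpKernelConstruction_GetAttrFloat(TF_OpKernelConstruction* ctx,
                                          const char* attr_name, float* val,
                                          TF_Status* status);
void TF_OpKernelConstruction_GetAttrFloatList(TF_OpKernelConstruction* ctx,
                                              const char* attr_name,
                                              float* vals, int max_vals,
                                              TF_Status* status);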
@@ -184,6 +184,24 @@ TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType(
 // Returns the step ID of the given context.
 TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx);
 
+// Get the list_size and total_size of the attribute `attr_name` of `ctx`.
+// list_size - the length of the list.
+// total_size - total size of the list.
+//   (1) If attr_type == TF_ATTR_STRING
+//       then total_size is the cumulative byte size
+//       of all the strings in the list.
+//   (2) If attr_type == TF_ATTR_SHAPE and the attribute is a single shape,
+//       then total_size is the number of dimensions
+//       of the shape valued attribute, or -1
+//       if its rank is unknown.
+//   (3) If attr_type == TF_ATTR_SHAPE and the attribute is a list of shapes,
+//       then total_size is the cumulative number
+//       of dimensions of all shapes in the list.
+//   (4) Otherwise, total_size is undefined.
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrSize(
+    TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* list_size,
+    int32_t* total_size, TF_Status* status);
+
 // Interprets the named kernel construction attribute as a TF_DataType and
 // places it into *val. *status is set to TF_OK.
 //
@@ -202,6 +220,112 @@ TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32(
     TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val,
     TF_Status* status);
 
+// Interprets the named kernel construction attribute as int64_t and
+// places it into *val. *status is set to TF_OK.
+//
+// If the attribute could not be found or could not be interpreted as
+// int64, *status is populated with an error.
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64(
+    TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* val,
+    TF_Status* status);
+
+// Interprets the named kernel construction attribute as float and
+// places it into *val. *status is set to TF_OK.
+//
+// If the attribute could not be found or could not be interpreted as
+// float, *status is populated with an error.
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloat(
+    TF_OpKernelConstruction* ctx, const char* attr_name, float* val,
+    TF_Status* status);
+
+// Interprets the named kernel construction attribute as bool and
+// places it into *val. *status is set to TF_OK.
+//
+// If the attribute could not be found or could not be interpreted as
+// bool, *status is populated with an error.
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBool(
+    TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* val,
+    TF_Status* status);
+
+// Interprets the named kernel construction attribute as string and
+// places it into *val. `val` must
+// point to an array of length at least `max_length` (ideally set to
+// total_size from TF_OpKernelConstruction_GetAttrSize(ctx,
+// attr_name, list_size, total_size)). *status is set to TF_OK.
+//
+// If the attribute could not be found or could not be interpreted as
+// string, *status is populated with an error.
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrString(
+    TF_OpKernelConstruction* ctx, const char* attr_name, char* val,
+    size_t max_length, TF_Status* status);
+
+// Interprets the named kernel construction attribute as a TF_DataType array
+// and places it into *vals. *status is set to TF_OK.
+// `vals` must point to an array of length at least `max_values` (ideally set
+// to list_size from
+// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
+// total_size)).
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTypeList(
+    TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* vals,
+    int max_vals, TF_Status* status);
+
+// Interprets the named kernel construction attribute as int32_t array and
+// places it into *vals. *status is set to TF_OK.
+// `vals` must point to an array of length at least `max_values` (ideally set
+// to list_size from
+// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
+// total_size)).
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32List(
+    TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* vals,
+    int max_vals, TF_Status* status);
+
+// Interprets the named kernel construction attribute as int64_t array and
+// places it into *vals. *status is set to TF_OK.
+// `vals` must point to an array of length at least `max_values` (ideally set
+// to list_size from
+// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
+// total_size)).
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64List(
+    TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* vals,
+    int max_vals, TF_Status* status);
+
+// Interprets the named kernel construction attribute as float array and
+// places it into *vals. *status is set to TF_OK.
+// `vals` must point to an array of length at least `max_values` (ideally set
+// to list_size from
+// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
+// total_size)).
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloatList(
+    TF_OpKernelConstruction* ctx, const char* attr_name, float* vals,
+    int max_vals, TF_Status* status);
+
+// Interprets the named kernel construction attribute as bool array and
+// places it into *vals. *status is set to TF_OK.
+// `vals` must point to an array of length at least `max_values` (ideally set
+// to list_size from
+// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
+// total_size)).
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBoolList(
+    TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* vals,
+    int max_vals, TF_Status* status);
+
+// Interprets the named kernel construction attribute as string array and fills
+// in `vals` and `lengths`, each of which must point to an array of length at
+// least `max_values`. *status is set to TF_OK. The elements of values will
+// point to addresses in `storage` which must be at least `storage_size` bytes
+// in length. Ideally, max_values would be set to list_size and `storage` would
+// be at least total_size, obtained from
+// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
+// total_size).
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrStringList(
+    TF_OpKernelConstruction* ctx, const char* attr_name, char** vals,
+    size_t* lengths, int max_values, void* storage, size_t storage_size,
+    TF_Status* status);
+
+// Returns true if the kernel construction has the attr_name.
+TF_CAPI_EXPORT extern bool TF_OpKernelConstruction_HasAttr(
+    TF_OpKernelConstruction* ctx, const char* attr_name, TF_Status* status);
+
 // Returns the unique operation name for this OpKernel.
 TF_CAPI_EXPORT extern TF_StringView TF_OpKernelConstruction_GetName(
     TF_OpKernelConstruction* ctx);
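The documented pattern above (query sizes first, then call the typed getter) is worth a short sketch. The following is illustrative only, not part of the diff: the function name MyKernelCreate and the attr name "Attr" are hypothetical, and the includes assume the headers shipped with the C kernel API:

#include <memory>
#include <string>

#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_status.h"

// Hedged sketch of a plugin kernel's create function reading a scalar string
// attr named "Attr" with the new size query plus typed getter.
void* MyKernelCreate(TF_OpKernelConstruction* ctx) {
  TF_Status* status = TF_NewStatus();
  int32_t list_size = 0;
  int32_t total_size = 0;
  // For a scalar string attr, list_size comes back as -1 and total_size is
  // the byte length of the string.
  TF_OpKernelConstruction_GetAttrSize(ctx, "Attr", &list_size, &total_size,
                                      status);
  std::string value;
  if (TF_GetCode(status) == TF_OK && total_size > 0) {
    value.resize(total_size);
    TF_OpKernelConstruction_GetAttrString(ctx, "Attr", &value[0],
                                          /*max_length=*/total_size, status);
  }
  TF_DeleteStatus(status);
  // A real kernel would stash `value` in its per-kernel state object and
  // return that pointer instead of nullptr.
  return nullptr;
}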
@@ -161,6 +161,337 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
   ASSERT_TRUE(delete_called);
 }
 
+// REGISTER_OP for TF_OpKernelConstruction_GetAttr* test cases.
+// Registers two ops, each with a single attribute called 'Attr'.
+// The attribute in one op will have a type 'type', the other
+// will have list(type).
+#define ATTR_TEST_REGISTER_OP(name, type)                     \
+  REGISTER_OP("TestKernelAttr" #name)                         \
+      .Attr("Attr: " #type)                                   \
+      .SetShapeFn(tensorflow::shape_inference::UnknownShape); \
+  REGISTER_OP("TestKernelAttr" #name "List")                  \
+      .Attr("Attr: list(" #type ")")                          \
+      .SetShapeFn(tensorflow::shape_inference::UnknownShape)
+ATTR_TEST_REGISTER_OP(String, string);
+ATTR_TEST_REGISTER_OP(Int, int);
+ATTR_TEST_REGISTER_OP(Float, float);
+ATTR_TEST_REGISTER_OP(Bool, bool);
+ATTR_TEST_REGISTER_OP(Type, type);
+#undef ATTR_TEST_REGISTER_OP
+
+// Helper macros for the TF_OpKernelConstruction_GetAttr* tests.
+#define EXPECT_TF_SIZE(attr_name, expected_list_size, expected_total_size) \
+  do {                                                                     \
+    int32_t list_size, total_size;                                         \
+    TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, &list_size,        \
+                                        &total_size, status);              \
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);            \
+    EXPECT_EQ(expected_list_size, list_size);                              \
+    EXPECT_EQ(expected_total_size, total_size);                            \
+  } while (0)
+
+typedef void* (*MyCreateFuncWithAttr)(TF_OpKernelConstruction*);
+class TestKernelAttr : public ::testing::Test {
+ public:
+  TestKernelAttr() {}
+  ~TestKernelAttr() {}
+
+  std::unique_ptr<OpKernel> GetFakeKernelWithAttr(const char* op_name,
+                                                  AttrValue v, Status* status) {
+    NodeDef def;
+    def.set_op(op_name);
+    def.set_name("FakeNode");
+    def.set_device("FakeDevice");
+    (*def.mutable_attr())["Attr"] = v;
+    return CreateOpKernel(DeviceType("FakeDevice"), nullptr, nullptr, def, 1,
+                          status);
+  }
+
+  void SetAttr(MyCreateFuncWithAttr MyCreateFuncAttr, const char* op_name,
+               AttrValue& v) {
+    TF_KernelBuilder* builder = TF_NewKernelBuilder(
+        op_name, "FakeDevice", MyCreateFuncAttr, &MyComputeFunc, &MyDeleteFunc);
+    {
+      TF_Status* status = TF_NewStatus();
+      TF_RegisterKernelBuilder("FakeNode", builder, status);
+      EXPECT_EQ(TF_OK, TF_GetCode(status));
+      TF_DeleteStatus(status);
+    }
+    Status status;
+    std::unique_ptr<OpKernel> kernel =
+        GetFakeKernelWithAttr(op_name, v, &status);
+    TF_EXPECT_OK(status);
+    ASSERT_NE(nullptr, kernel.get());
+    kernel->Compute(nullptr);
+
+    ASSERT_TRUE(delete_called);
+  }
+};
+
+TEST_F(TestKernelAttr, String) {
+  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
+    struct MyCustomKernel* s = new struct MyCustomKernel;
+    s->created = true;
+    s->compute_called = false;
+
+    std::unique_ptr<char[]> val(new char[5]);
+    TF_Status* status = TF_NewStatus();
+    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
+                   /*expected_total_size*/ 5);
+    TF_OpKernelConstruction_GetAttrString(ctx, "Attr", val.get(),
+                                          /*max_length*/ 5, status);
+
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    EXPECT_EQ("bunny", string(static_cast<const char*>(val.get()), 5));
+    TF_DeleteStatus(status);
+    return static_cast<void*>(s);
+  };
+
+  AttrValue v;
+  v.set_s("bunny");
+  SetAttr(my_create_func, "TestKernelAttrString", v);
+}
+
+TEST_F(TestKernelAttr, StringList) {
+  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
+    struct MyCustomKernel* s = new struct MyCustomKernel;
+    s->created = true;
+    s->compute_called = false;
+
+    std::vector<string> list = {"bugs", "bunny", "duck"};
+    int list_total_size = 0;
+    for (const auto& s : list) {
+      list_total_size += s.size();
+    }
+
+    TF_Status* status = TF_NewStatus();
+    std::unique_ptr<char*[]> values(new char*[list.size()]);
+    std::unique_ptr<size_t[]> lens(new size_t[list.size()]);
+    std::unique_ptr<char[]> storage(new char[list_total_size]);
+    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list.size(),
+                   /*expected_total_size*/ list_total_size);
+    TF_OpKernelConstruction_GetAttrStringList(
+        ctx, "Attr", values.get(), lens.get(), list.size(), storage.get(),
+        list_total_size, status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+    for (size_t i = 0; i < list.size(); ++i) {
+      EXPECT_EQ(list[i].size(), lens[i]) << i;
+      EXPECT_EQ(list[i], string(static_cast<const char*>(values[i]), lens[i]))
+          << i;
+    }
+    TF_DeleteStatus(status);
+    return static_cast<void*>(s);
+  };
+
+  AttrValue v;
+  auto attr_in = gtl::ArraySlice<StringPiece>({"bugs", "bunny", "duck"});
+  SetAttrValue(attr_in, &v);
+  SetAttr(my_create_func, "TestKernelAttrStringList", v);
+}
+
+TEST_F(TestKernelAttr, Int) {
+  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
+    struct MyCustomKernel* s = new struct MyCustomKernel;
+    s->created = true;
+    s->compute_called = false;
+
+    int64_t val;
+    TF_Status* status = TF_NewStatus();
+    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
+                   /*expected_total_size*/ -1);
+    TF_OpKernelConstruction_GetAttrInt64(ctx, "Attr", &val, status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    EXPECT_EQ(1234, val);
+    TF_DeleteStatus(status);
+    return static_cast<void*>(s);
+  };
+
+  AttrValue v;
+  v.set_i(1234);
+  SetAttr(my_create_func, "TestKernelAttrInt", v);
+}
+
+TEST_F(TestKernelAttr, IntList) {
+  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
+    struct MyCustomKernel* s = new struct MyCustomKernel;
+    s->created = true;
+    s->compute_called = false;
+
+    const int64_t list[] = {1, 2, 3, 4};
+    const size_t list_size = TF_ARRAYSIZE(list);
+    int64_t values[list_size];
+
+    TF_Status* status = TF_NewStatus();
+    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
+                   /*expected_total_size*/ -1);
+    TF_OpKernelConstruction_GetAttrInt64List(ctx, "Attr", values, list_size,
+                                             status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    EXPECT_TRUE(
+        std::equal(std::begin(list), std::end(list), std::begin(values)));
+    TF_DeleteStatus(status);
+    return static_cast<void*>(s);
+  };
+
+  AttrValue v;
+  auto attr_in = gtl::ArraySlice<int64>({1, 2, 3, 4});
+  SetAttrValue(attr_in, &v);
+  SetAttr(my_create_func, "TestKernelAttrIntList", v);
+}
+
+TEST_F(TestKernelAttr, Float) {
+  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
+    struct MyCustomKernel* s = new struct MyCustomKernel;
+    s->created = true;
+    s->compute_called = false;
+
+    float val;
+    TF_Status* status = TF_NewStatus();
+    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
+                   /*expected_total_size*/ -1);
+    TF_OpKernelConstruction_GetAttrFloat(ctx, "Attr", &val, status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    EXPECT_FLOAT_EQ(2.718, val);
+    TF_DeleteStatus(status);
+    return static_cast<void*>(s);
+  };
+
+  AttrValue v;
+  v.set_f(2.718);
+  SetAttr(my_create_func, "TestKernelAttrFloat", v);
+}
+
+TEST_F(TestKernelAttr, FloatList) {
+  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
+    struct MyCustomKernel* s = new struct MyCustomKernel;
+    s->created = true;
+    s->compute_called = false;
+
+    const float list[] = {1.414, 2.718, 3.1415};
+    const size_t list_size = TF_ARRAYSIZE(list);
+    float values[list_size];
+
+    TF_Status* status = TF_NewStatus();
+    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
+                   /*expected_total_size*/ -1);
+    TF_OpKernelConstruction_GetAttrFloatList(ctx, "Attr", values, list_size,
+                                             status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    EXPECT_TRUE(
+        std::equal(std::begin(list), std::end(list), std::begin(values)));
+    TF_DeleteStatus(status);
+    return static_cast<void*>(s);
+  };
+
+  AttrValue v;
+  auto attr_in = gtl::ArraySlice<float>({1.414, 2.718, 3.1415});
+  SetAttrValue(attr_in, &v);
+  SetAttr(my_create_func, "TestKernelAttrFloatList", v);
+}
+
+TEST_F(TestKernelAttr, Bool) {
+  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
+    struct MyCustomKernel* s = new struct MyCustomKernel;
+    s->created = true;
+    s->compute_called = false;
+
+    unsigned char val;
+    TF_Status* status = TF_NewStatus();
+    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
+                   /*expected_total_size*/ -1);
+    TF_OpKernelConstruction_GetAttrBool(ctx, "Attr", &val, status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    EXPECT_EQ(1, val);
+    TF_DeleteStatus(status);
+    return static_cast<void*>(s);
+  };
+
+  AttrValue v;
+  v.set_b(1);
+  SetAttr(my_create_func, "TestKernelAttrBool", v);
+}
+
+TEST_F(TestKernelAttr, BoolList) {
+  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
+    struct MyCustomKernel* s = new struct MyCustomKernel;
+    s->created = true;
+    s->compute_called = false;
+
+    const unsigned char list[] = {1, 0, 1, 0};
+    const size_t list_size = TF_ARRAYSIZE(list);
+    unsigned char values[list_size];
+
+    TF_Status* status = TF_NewStatus();
+    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
+                   /*expected_total_size*/ -1);
+    TF_OpKernelConstruction_GetAttrBoolList(ctx, "Attr", values, list_size,
+                                            status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    EXPECT_TRUE(
+        std::equal(std::begin(list), std::end(list), std::begin(values)));
+    TF_DeleteStatus(status);
+    return static_cast<void*>(s);
+  };
+
+  AttrValue v;
+  auto attr_in = gtl::ArraySlice<bool>({1, 0, 1, 0});
+  SetAttrValue(attr_in, &v);
+  SetAttr(my_create_func, "TestKernelAttrBoolList", v);
+}
+
+TEST_F(TestKernelAttr, Type) {
+  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
+    struct MyCustomKernel* s = new struct MyCustomKernel;
+    s->created = true;
+    s->compute_called = false;
+
+    TF_DataType val;
+    TF_Status* status = TF_NewStatus();
+    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
+                   /*expected_total_size*/ -1);
+    TF_OpKernelConstruction_GetAttrType(ctx, "Attr", &val, status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    EXPECT_EQ(TF_FLOAT, val);
+    TF_DeleteStatus(status);
+    return static_cast<void*>(s);
+  };
+
+  AttrValue v;
+  v.set_type(DT_FLOAT);
+  SetAttr(my_create_func, "TestKernelAttrType", v);
+}
+
+TEST_F(TestKernelAttr, TypeList) {
+  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
+    struct MyCustomKernel* s = new struct MyCustomKernel;
+    s->created = true;
+    s->compute_called = false;
+
+    const TF_DataType list[] = {TF_FLOAT, TF_DOUBLE, TF_HALF, TF_COMPLEX128};
+    const size_t list_size = TF_ARRAYSIZE(list);
+    TF_DataType values[list_size];
+
+    TF_Status* status = TF_NewStatus();
+    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
+                   /*expected_total_size*/ -1);
+    TF_OpKernelConstruction_GetAttrTypeList(ctx, "Attr", values, list_size,
+                                            status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    EXPECT_TRUE(
+        std::equal(std::begin(list), std::end(list), std::begin(values)));
+    TF_DeleteStatus(status);
+    return static_cast<void*>(s);
+  };
+
+  AttrValue v;
+  auto attr_in =
+      gtl::ArraySlice<DataType>({DT_FLOAT, DT_DOUBLE, DT_HALF, DT_COMPLEX128});
+  SetAttrValue(attr_in, &v);
+  SetAttr(my_create_func, "TestKernelAttrTypeList", v);
+}
+#undef EXPECT_TF_SIZE
+
 class DummyDevice : public DeviceBase {
  public:
  explicit DummyDevice(Env* env) : DeviceBase(env) {}
@@ -65,6 +65,7 @@ MAP_HLO_TO_LHLO(MinOp);
 MAP_HLO_TO_LHLO(MulOp);
 MAP_HLO_TO_LHLO(NegOp);
 MAP_HLO_TO_LHLO(NotOp);
+MAP_HLO_TO_LHLO(OrOp);
 MAP_HLO_TO_LHLO(RealOp);
 MAP_HLO_TO_LHLO(ReduceOp);
 MAP_HLO_TO_LHLO(ReshapeOp);
@@ -81,6 +82,7 @@ MAP_HLO_TO_LHLO(SqrtOp);
 MAP_HLO_TO_LHLO(SubOp);
 MAP_HLO_TO_LHLO(TanhOp);
 MAP_HLO_TO_LHLO(TransposeOp);
+MAP_HLO_TO_LHLO(XorOp);
 
 #undef MAP_HLO_TO_LHLO
 
@@ -481,6 +481,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
   return nullptr;
 }
 
+template <>
+inline Value MapLhloOpToStdScalarOp<lmhlo::OrOp>(Location loc,
+                                                 ArrayRef<Type> result_types,
+                                                 ArrayRef<Value> args,
+                                                 OpBuilder* b) {
+  return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::OrOp>{}(
+      loc, result_types, args, b);
+}
+
 template <>
 inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
                                                     ArrayRef<Type> result_types,
@@ -580,6 +589,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
       loc, result_types, args, b);
 }
 
+template <>
+inline Value MapLhloOpToStdScalarOp<lmhlo::XorOp>(Location loc,
+                                                  ArrayRef<Type> result_types,
+                                                  ArrayRef<Value> args,
+                                                  OpBuilder* b) {
+  return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::XOrOp>{}(
+      loc, result_types, args, b);
+}
+
 }  // namespace impl
 
 struct HloOpToStdScalarOp {
@@ -629,6 +629,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
       HloToLhloOpConverter<mhlo::MulOp>,
       HloToLhloOpConverter<mhlo::NegOp>,
       HloToLhloOpConverter<mhlo::NotOp>,
+      HloToLhloOpConverter<mhlo::OrOp>,
       HloToLhloOpConverter<mhlo::RealOp>,
      HloToLhloOpConverter<mhlo::RemOp>,
       HloToLhloOpConverter<mhlo::RsqrtOp>,
@@ -644,6 +645,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
       HloToLhloOpConverter<mhlo::SubOp>,
       HloToLhloOpConverter<mhlo::TanhOp>,
       HloToLhloOpConverter<mhlo::TransposeOp>,
+      HloToLhloOpConverter<mhlo::XorOp>,
       HloToLhloReduceOpConverter,
       HloToLhloReturnOpConverter,
       HloToLhloTensorLoadOpConverter,
@@ -927,12 +927,14 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
       PointwiseToLinalgConverter<lmhlo::ExpOp>,
       PointwiseToLinalgConverter<lmhlo::FloorOp>,
       PointwiseToLinalgConverter<lmhlo::ImagOp>,
+      PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
       PointwiseToLinalgConverter<lmhlo::LogOp>,
       PointwiseToLinalgConverter<lmhlo::MaxOp>,
       PointwiseToLinalgConverter<lmhlo::MinOp>,
       PointwiseToLinalgConverter<lmhlo::MulOp>,
       PointwiseToLinalgConverter<lmhlo::NegOp>,
       PointwiseToLinalgConverter<lmhlo::NotOp>,
+      PointwiseToLinalgConverter<lmhlo::OrOp>,
       PointwiseToLinalgConverter<lmhlo::RealOp>,
       PointwiseToLinalgConverter<lmhlo::RemOp>,
       PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
@@ -945,7 +947,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
       PointwiseToLinalgConverter<lmhlo::SqrtOp>,
       PointwiseToLinalgConverter<lmhlo::SubOp>,
       PointwiseToLinalgConverter<lmhlo::TanhOp>,
-      PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
+      PointwiseToLinalgConverter<lmhlo::XorOp>,
       ReduceConverter,
       ReshapeOpConverter<lmhlo::ReshapeOp>,
       ReverseConverter<lmhlo::ReverseOp>,
@@ -1042,12 +1044,14 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
       PointwiseToLinalgConverter<mhlo::ExpOp, false>,
       PointwiseToLinalgConverter<mhlo::FloorOp, false>,
       PointwiseToLinalgConverter<mhlo::ImagOp, false>,
+      PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
       PointwiseToLinalgConverter<mhlo::LogOp, false>,
       PointwiseToLinalgConverter<mhlo::MaxOp, false>,
       PointwiseToLinalgConverter<mhlo::MinOp, false>,
       PointwiseToLinalgConverter<mhlo::MulOp, false>,
       PointwiseToLinalgConverter<mhlo::NegOp, false>,
       PointwiseToLinalgConverter<mhlo::NotOp, false>,
+      PointwiseToLinalgConverter<mhlo::OrOp, false>,
       PointwiseToLinalgConverter<mhlo::RealOp, false>,
       PointwiseToLinalgConverter<mhlo::RemOp, false>,
       PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
@@ -1055,11 +1059,12 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
       PointwiseToLinalgConverter<mhlo::ShiftLeftOp, false>,
       PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp, false>,
       PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp, false>,
+      PointwiseToLinalgConverter<mhlo::SignOp, false>,
       PointwiseToLinalgConverter<mhlo::SinOp, false>,
       PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
       PointwiseToLinalgConverter<mhlo::SubOp, false>,
       PointwiseToLinalgConverter<mhlo::TanhOp, false>,
-      PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
+      PointwiseToLinalgConverter<mhlo::XorOp, false>,
       ReshapeOpConverter<mhlo::ReshapeOp, false>,
       ReverseConverter<mhlo::ReverseOp, false>,
       TransposeConverter<mhlo::TransposeOp, false>>(context);
@@ -42,11 +42,11 @@ namespace {
       sep fn(SqrtOp) sep fn(TanhOp)
 
 // TODO(herhut): Generate these out of op definitions.
 #define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep) \
-  fn(AddOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp) sep fn(MaxOp) \
-      sep fn(MinOp) sep fn(MulOp) sep fn(PowOp) sep fn(RemOp) \
-      sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
-      sep fn(ShiftRightLogicalOp) sep fn(SubOp)
+  fn(AddOp) sep fn(AndOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp) \
+      sep fn(MaxOp) sep fn(MinOp) sep fn(MulOp) sep fn(OrOp) sep fn(PowOp) \
+      sep fn(RemOp) sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
+      sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp)
 
 // TODO(herhut): Generate these out of op definitions.
 #define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
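Note: MAP_XLA_OPERATION_CWISE_BINARY above is an X-macro list: callers pass an `fn` macro that is applied to each op name, with `sep` spliced between applications. The toy C++ sketch below (hypothetical names, not the TensorFlow macro itself) shows how such a list is typically consumed, for example to generate a table of op-name strings:

    #include <cstdio>

    // Hypothetical X-macro list mirroring the structure of the macro above.
    #define MAP_BINARY_OPS(fn, sep) fn(AddOp) sep fn(AndOp) sep fn(XorOp)

    // One expansion: build a comma-separated array of op-name strings.
    #define OP_NAME(op) #op
    #define COMMA ,
    static const char* const kBinaryOpNames[] = {MAP_BINARY_OPS(OP_NAME, COMMA)};

    int main() {
      // Prints: AddOp AndOp XorOp
      for (const char* name : kBinaryOpNames) std::printf("%s ", name);
      std::printf("\n");
      return 0;
    }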
@@ -316,6 +316,20 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @and
+func @and(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
+          %result: memref<2x2xi32>) {
+  %tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
+  %tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
+  %tensor_result = "mhlo.and"(%tensor_operand0, %tensor_operand1)
+      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  // CHECK: "lmhlo.and"(%{{.*}}, %{{.*}}, %{{.*}})
+  tensor_store %tensor_result, %result : memref<2x2xi32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @ceil
 func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -389,6 +403,20 @@ func @not(%operand: memref<2x2xi32>, %result: memref<2x2xi32>) {
 
 // -----
 
+// CHECK-LABEL: func @or
+func @or(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
+         %result: memref<2x2xi32>) {
+  %tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
+  %tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
+  %tensor_result = "mhlo.or"(%tensor_operand0, %tensor_operand1)
+      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  // CHECK: "lmhlo.or"(%{{.*}}, %{{.*}}, %{{.*}})
+  tensor_store %tensor_result, %result : memref<2x2xi32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @rsqrt
 func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -480,7 +508,8 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 // -----
 
 // CHECK-LABEL: func @remainder
-func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
+func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
+                %result: memref<2x2xf32>) {
   %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
   %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
   %tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
@@ -492,6 +521,20 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
 
 // -----
 
+// CHECK-LABEL: func @xor
+func @xor(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
+          %result: memref<2x2xi32>) {
+  %tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
+  %tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
+  %tensor_result = "mhlo.xor"(%tensor_operand0, %tensor_operand1)
+      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  // CHECK: "lmhlo.xor"(%{{.*}}, %{{.*}})
+  tensor_store %tensor_result, %result : memref<2x2xi32>
+  return
+}
+
+// -----
+
 // Dynamic shape binary element-wise operation.
 // CHECK-LABEL: func @add_dyn
 func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
@@ -194,6 +194,30 @@ func @integer_and(%lhs: tensor<2x2xi32>,
 
 // -----
 
+// CHECK-LABEL: func @integer_or
+func @integer_or(%lhs: tensor<2x2xi32>,
+                 %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
+  // CHECK: linalg.generic
+  // CHECK: or
+  %0 = "mhlo.or"(%lhs, %rhs) : (tensor<2x2xi32>,
+                                tensor<2x2xi32>) -> tensor<2x2xi32>
+  return %0 : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @integer_xor
+func @integer_xor(%lhs: tensor<2x2xi32>,
+                  %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
+  // CHECK: linalg.generic
+  // CHECK: xor
+  %0 = "mhlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>,
+                                 tensor<2x2xi32>) -> tensor<2x2xi32>
+  return %0 : tensor<2x2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @float_cmp
 func @float_cmp(%lhs: tensor<2x2xf32>,
                 %rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
@@ -509,7 +509,7 @@ Operation* BuildVariableOp(const tflite::TensorT& tensor,
     return op.getOperation();
   }
   auto op = builder.create<tfl::ConstOp>(loc, value);
-  if (!tensor.quantization->min.empty()) {
+  if (tensor.quantization && !tensor.quantization->min.empty()) {
     if (auto stats_op =
             ConvertMinMaxToStatsOp(tensor, builder, op.getResult())) {
       return stats_op;
@@ -1977,6 +1977,7 @@ cc_library(
    hdrs = ["utils/bridge_logger.h"],
    deps = [
        ":dump_mlir_util",
+       "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
@@ -5028,6 +5028,26 @@ Table initializer that takes two tensors for keys and values respectively.
   TF_DerivedOperandTypeAttr Tkey = TF_DerivedOperandTypeAttr<1>;
 }
 
+def TF_InplaceAddOp : TF_Op<"InplaceAdd", [AllTypesMatch<["x", "y"]>, NoSideEffect]> {
+  let summary = "Adds v into specified rows of x.";
+
+  let description = [{
+Computes y = x; y[i, :] += v; return y.
+  }];
+
+  let arguments = (ins
+    TF_Tensor:$x,
+    TF_Int32Tensor:$i,
+    TF_Tensor:$v
+  );
+
+  let results = (outs
+    TF_Tensor:$y
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> {
   let summary = "Updates specified rows 'i' with values 'v'.";
 
@@ -5374,6 +5394,37 @@ def TF_IteratorV2Op : TF_Op<"IteratorV2", []> {
   );
 }
 
+def TF_KthOrderStatisticOp : TF_Op<"KthOrderStatistic", [NoSideEffect]> {
+  let summary = "Computes the Kth order statistic of a data set. The current";
+
+  let description = [{
+implementation uses a binary search requiring exactly 32 passes over
+the input data. The running time is linear with respect to input
+size. The median-of-medians algorithm is probably faster, but is
+difficult to implement efficiently in XLA. The implementation imposes
+a total ordering on floats. The ordering is consistent with the usual
+partial order. Positive NaNs are greater than positive
+infinity. Negative NaNs are less than negative infinity. NaNs with
+distinct payloads are treated as distinct. Subnormal numbers are
+preserved (not flushed to zero). Positive infinity is greater than all
+numbers. Negative infinity is less than all numbers. Positive is
+greater than negative zero. There are less than k values greater than
+the kth order statistic. There are at least k values greater than or
+equal to the Kth order statistic. The semantics are not the same as
+top_k_unique.
+  }];
+
+  let arguments = (ins
+    TF_Float32Tensor:$input,
+
+    I64Attr:$k
+  );
+
+  let results = (outs
+    TF_Float32Tensor:$output
+  );
+}
+
 def TF_L2LossOp : TF_Op<"L2Loss", [NoSideEffect]> {
   let summary = "L2 Loss.";
 
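The KthOrderStatistic description above relies on a total ordering of float32 values, with NaNs and signed zeros placed consistently with the usual partial order. One conventional way to impose such an order is to map each float's bit pattern to an unsigned key and compare the keys; the sketch below illustrates that general technique under that assumption and is not the XLA implementation (the exact NaN placement XLA uses may differ in detail):

    #include <cstdint>
    #include <cstring>

    // Maps a float's bit pattern to a uint32 key whose unsigned ordering is a
    // total order over all float32 values: negative NaNs < -inf < negative
    // numbers < -0.0 < +0.0 < positive numbers < +inf < positive NaNs.
    uint32_t TotalOrderKey(float value) {
      uint32_t bits;
      std::memcpy(&bits, &value, sizeof(bits));
      // Negative floats: invert all bits so larger magnitudes sort lower.
      // Non-negative floats: set the sign bit so they sort above all negatives.
      return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
    }

    // Comparison consistent with the ordering described above:
    //   a "less than" b  <=>  TotalOrderKey(a) < TotalOrderKey(b)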
@@ -6505,6 +6556,27 @@ iterator in `iterator` to the first element of `dataset`.
   let results = (outs);
 }
 
+def TF_MakeUniqueOp : TF_Op<"MakeUnique", [NoSideEffect]> {
+  let summary = [{
+Make all elements in the non-Batch dimension unique, but \"close\" to
+  }];
+
+  let description = [{
+their initial value. Never returns a sub-normal number. Never returns
+zero. The sign of each input element is always identical to the sign
+of the corresponding output element. Behavior for infinite elements is
+undefined. Behavior for subnormal elements is undefined.
+  }];
+
+  let arguments = (ins
+    TF_Float32Tensor:$input
+  );
+
+  let results = (outs
+    TF_Float32Tensor:$output
+  );
+}
+
 def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> {
   let summary = [{
 Multiply the matrix "a" by the matrix "b".
@@ -15234,6 +15306,36 @@ array([[1, 2, 3, 1, 2, 3],
   let hasFolder = 1;
 }
 
+def TF_TopKUniqueOp : TF_Op<"TopKUnique", [NoSideEffect]> {
+  let summary = "Returns the TopK unique values in the array in sorted order.";
+
+  let description = [{
+The running time is proportional to the product of K and the input
+size. Sorting the whole array is more efficient for sufficiently large
+values of K. The median-of-medians algorithm is probably faster, but
+difficult to implement efficiently in XLA. If there are fewer than K
+unique numbers (not NANs), the results are padded with negative
+infinity. NaNs are never returned. Subnormal numbers are flushed to
+zero. If an element appears at multiple indices, the highest index is
+returned. If a TopK element never appears in the input due to padding
+values, the indices are padded with negative one. If a padding value
+appears in the input and padding is needed, the highest index of the
+padding value will be returned. The semantics are not the same as
+kth_order_statistic.
+  }];
+
+  let arguments = (ins
+    TF_Float32Tensor:$input,
+
+    I64Attr:$k
+  );
+
+  let results = (outs
+    TF_Float32Tensor:$topk,
+    TF_Int32Tensor:$topk_indices
+  );
+}
+
 def TF_TopKV2Op : TF_Op<"TopKV2", [NoSideEffect]> {
   let summary = [{
 Finds values and indices of the `k` largest elements for the last dimension.
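To make the padding rule in the TopKUnique description concrete, here is a hedged standalone C++ sketch of the same contract on a plain vector: keep the K largest unique non-NaN values and pad with negative infinity when fewer than K unique values exist. Subnormal flushing and index handling are omitted; names and structure are illustrative only, not the XLA kernel.

    #include <algorithm>
    #include <cmath>
    #include <functional>
    #include <limits>
    #include <set>
    #include <vector>

    // Returns the K largest unique non-NaN values in descending order,
    // padded with -inf when fewer than K unique values are present.
    std::vector<float> TopKUnique(const std::vector<float>& input, int k) {
      std::set<float, std::greater<float>> unique;  // descending, deduplicated
      for (float v : input) {
        if (!std::isnan(v)) unique.insert(v);
      }
      std::vector<float> result;
      result.reserve(k);
      for (float v : unique) {
        if (static_cast<int>(result.size()) == k) break;
        result.push_back(v);
      }
      while (static_cast<int>(result.size()) < k) {
        result.push_back(-std::numeric_limits<float>::infinity());
      }
      return result;
    }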
@@ -15269,6 +15371,29 @@ If two elements are equal, the lower-index element appears first.
   let verifier = [{ return Verify(*this); }];
 }
 
+def TF_TopKWithUniqueOp : TF_Op<"TopKWithUnique", [NoSideEffect]> {
+  let summary = "Returns the TopK values in the array in sorted order.";
+
+  let description = [{
+This is a combination of MakeUnique and TopKUnique. The returned top-K will
+have its lower bits replaced by iota, thus it will be close to the original
+value but not exactly the same. The running time is proportional to the product
+of K and the input size. NaNs are never returned. Subnormal numbers are flushed
+to zero.
+  }];
+
+  let arguments = (ins
+    TF_Float32Tensor:$input,
+
+    I64Attr:$k
+  );
+
+  let results = (outs
+    TF_Float32Tensor:$topk,
+    TF_Int32Tensor:$topk_indices
+  );
+}
+
 def TF_TransposeOp : TF_Op<"Transpose", [NoSideEffect]> {
   let summary = "Shuffle dimensions of x according to a permutation.";
 
@@ -532,7 +532,11 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
   auto ranked_type = type.dyn_cast<RankedTensorType>();
   if (!ranked_type) return {};
 
-  auto output_type = getType().cast<ShapedType>();
+  // DenseIntElementsAttr::get requires the output type be ranked with static
+  // shape.
+  auto output_type = getType().dyn_cast<RankedTensorType>();
+  if (!output_type || !output_type.hasStaticShape()) return {};
+
   int32_t rank = ranked_type.getRank();
   return DenseIntElementsAttr::get(output_type, rank);
 }
@@ -904,6 +904,20 @@ func @testRankOfRankedTensor(%arg0 : tensor<4x3x2xf32>) -> tensor<i32> {
   return %0 : tensor<i32>
 }
 
+// CHECK-LABEL: testRankOfRankedTensorUnrankedOutput
+func @testRankOfRankedTensorUnrankedOutput(%arg0 : tensor<4x3x2xf32>) -> tensor<*xi32> {
+  // Regression test to make sure we don't crash in this case.
+  %0 = "tf.Rank"(%arg0) : (tensor<4x3x2xf32>) -> tensor<*xi32>
+  return %0 : tensor<*xi32>
+}
+
+// CHECK-LABEL: testRankOfRankedTensorDynamicShapeOutput
+func @testRankOfRankedTensorDynamicShapeOutput(%arg0 : tensor<4x3x2xf32>) -> tensor<?xi32> {
+  // Regression test to make sure we don't crash in this case.
+  %0 = "tf.Rank"(%arg0) : (tensor<4x3x2xf32>) -> tensor<?xi32>
+  return %0 : tensor<?xi32>
+}
+
 // CHECK-LABEL: @foldFill
 func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>) {
   %0 = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
@@ -1627,8 +1627,10 @@ LogicalResult InferModuleShape(ModuleOp module, int64_t max_iterations) {
     return success();
   }
   int64_t producer = producer_or.ValueOrDie();
+  // TODO(jpienaar): Clean up propagate_caller_callee_constants if it is no
+  // longer needed.
   ShapeInference context(producer, module.getContext(),
-                         /*propagate_caller_callee_constants=*/true);
+                         /*propagate_caller_callee_constants=*/false);
   if (auto main = module.lookupSymbol<mlir::FuncOp>("main"))
     context.enqueue(main);
   for (auto func : module.getOps<FuncOp>()) context.enqueue(func);
@@ -15,6 +15,9 @@ limitations under the License.
 
 #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
 
+#include <atomic>
+
+#include "absl/strings/str_split.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/IR/Operation.h"  // from @llvm-project
@@ -23,17 +26,30 @@ limitations under the License.
 
 namespace tensorflow {
 
+// Counter is used as a prefix for filenames.
+static std::atomic<int> log_counter(0);
+
 BridgeLoggerConfig::BridgeLoggerConfig(bool print_module_scope,
                                        bool print_after_only_on_change)
     : mlir::PassManager::IRPrinterConfig(print_module_scope,
-                                         print_after_only_on_change) {}
+                                         print_after_only_on_change) {
+  const char* log_pass_patterns = getenv("MLIR_BRIDGE_LOG_PASS_PATTERNS");
+  if (log_pass_patterns) {
+    log_pass_patterns_ =
+        absl::StrSplit(log_pass_patterns, ',', absl::SkipWhitespace());
+  }
+}
 
-// Logs op to file with name of format `mlir_bridge-pass_name-file_suffix.mlir`.
+// Logs op to file with name of format
+// `<log_counter>_mlir_bridge_<pass_name>_<file_suffix>.mlir`.
 inline static void Log(BridgeLoggerConfig::PrintCallbackFn print_callback,
                        mlir::Pass* pass, mlir::Operation* op,
                        llvm::StringRef file_suffix) {
-  std::string name =
-      llvm::formatv("mlir_bridge_{0}_{1}", pass->getName(), file_suffix).str();
+  std::string pass_name = pass->getName().str();
+
+  // Add 4-digit counter as prefix so the order of the passes is obvious.
+  std::string name = llvm::formatv("{0,0+4}_mlir_bridge_{1}_{2}", log_counter++,
+                                   pass_name, file_suffix);
+
   std::unique_ptr<llvm::raw_ostream> os;
   std::string filepath;
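The hunk above prefixes each dump with a monotonically increasing, zero-padded counter so that the files sort in pass-execution order. A minimal standalone sketch of the same naming scheme, using only the standard library instead of llvm::formatv (helper and pass names here are hypothetical):

    #include <atomic>
    #include <cstdio>
    #include <string>

    static std::atomic<int> counter(0);

    // Builds names like "0000_mlir_bridge_my-pass_before.mlir" so that an
    // alphabetical directory listing reflects pass execution order.
    std::string MakeDumpName(const std::string& pass_name,
                             const std::string& suffix) {
      char prefix[8];
      std::snprintf(prefix, sizeof(prefix), "%04d", counter++);
      return std::string(prefix) + "_mlir_bridge_" + pass_name + "_" + suffix +
             ".mlir";
    }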
@@ -44,13 +60,30 @@ inline static void Log(BridgeLoggerConfig::PrintCallbackFn print_callback,
 void BridgeLoggerConfig::printBeforeIfEnabled(mlir::Pass* pass,
                                               mlir::Operation* operation,
                                               PrintCallbackFn print_callback) {
-  Log(print_callback, pass, operation, "before");
+  if (should_print(pass)) Log(print_callback, pass, operation, "before");
 }
 
 void BridgeLoggerConfig::printAfterIfEnabled(mlir::Pass* pass,
                                              mlir::Operation* operation,
                                              PrintCallbackFn print_callback) {
-  Log(print_callback, pass, operation, "after");
+  if (should_print(pass)) Log(print_callback, pass, operation, "after");
+}
+
+bool BridgeLoggerConfig::should_print(mlir::Pass* pass) {
+  if (log_pass_patterns_.empty()) return true;
+
+  std::string pass_name = pass->getName().str();
+  for (const auto& pattern : log_pass_patterns_) {
+    if (pass_name.find(pattern) != std::string::npos) {
+      // pattern matches pass
+      return true;
+    }
+  }
+  // no pattern matches pass
+  VLOG(2) << "Not logging pass " << pass_name
+          << " because it does not match any pattern in "
+             "MLIR_BRIDGE_LOG_PASS_PATTERNS";
+  return false;
 }
 
 void BridgeTimingConfig::printTiming(PrintCallbackFn printCallback) {
@@ -23,7 +23,11 @@ limitations under the License.
 namespace tensorflow {
 
 // Logger for logging/dumping MLIR modules before and after passes in bridge
-// targeting TPUs.
+// targeting TPUs. The passes being logged can be restricted via environment
+// variable `MLIR_BRIDGE_LOG_PASS_PATTERNS` which is interpreted as a comma-
+// separated list of strings, and only passes whose name contains any of those
+// strings as a substring are logged (no regex support). If
+// `MLIR_BRIDGE_LOG_PASS_PATTERNS` is not defined, then all passes are logged.
 class BridgeLoggerConfig : public mlir::PassManager::IRPrinterConfig {
  public:
   explicit BridgeLoggerConfig(bool print_module_scope = false,
@@ -42,6 +46,14 @@ class BridgeLoggerConfig : public mlir::PassManager::IRPrinterConfig {
   // with the stream to dump into.
   void printAfterIfEnabled(mlir::Pass *pass, mlir::Operation *operation,
                            PrintCallbackFn print_callback) override;
+
+ private:
+  bool should_print(mlir::Pass *pass);
+
+  // Only print passes that match any of these patterns. A pass matches a
+  // pattern if its name contains the pattern as a substring. If
+  // `log_pass_patterns_` is empty, print all passes.
+  std::vector<std::string> log_pass_patterns_;
 };
 
 // Logger for logging/dumping pass pipeline timings after completion.
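As documented above, pass filtering is driven entirely by the MLIR_BRIDGE_LOG_PASS_PATTERNS environment variable with plain substring matching. The standalone sketch below reimplements that documented rule for illustration only (it is not the TensorFlow code, and the pattern and pass names are just examples); the variable must be set before the logger is constructed:

    #include <cstdlib>
    #include <string>
    #include <vector>

    #include "absl/strings/str_split.h"

    // Mirrors the documented rule: an empty pattern list logs everything,
    // otherwise a pass is logged iff its name contains one of the patterns.
    bool ShouldLogPass(const std::string& pass_name,
                       const std::vector<std::string>& patterns) {
      if (patterns.empty()) return true;
      for (const std::string& pattern : patterns) {
        if (pass_name.find(pattern) != std::string::npos) return true;
      }
      return false;
    }

    int main() {
      // Restrict logging to passes whose names contain these substrings.
      setenv("MLIR_BRIDGE_LOG_PASS_PATTERNS",
             "ShapeInference,FunctionalControlFlow", /*overwrite=*/1);

      std::vector<std::string> patterns = absl::StrSplit(
          getenv("MLIR_BRIDGE_LOG_PASS_PATTERNS"), ',', absl::SkipWhitespace());
      return ShouldLogPass("TensorFlowShapeInferencePass", patterns) ? 0 : 1;
    }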
@@ -516,6 +516,11 @@ class ConvertToHloModule {
   //
   // TODO(hinsu): Check for dynamic shapes and exit instead of crashing.
   LogicalResult Run() {
+    auto main = module_.lookupSymbol<mlir::FuncOp>("main");
+    if (!main)
+      return module_.emitError(
+          "conversion requires module with `main` function");
+
     for (auto func : module_.getOps<FuncOp>()) {
       if (func.empty()) continue;
       if (failed(RunOnFunction(func))) return failure();
@@ -539,8 +544,11 @@ class ConvertToHloModule {
                                 xla::XlaComputation* result);
 
   ::xla::HloModuleProto ConsumeMainProto() {
-    return lowered_computation_[module_.lookupSymbol<mlir::FuncOp>("main")]
-        .proto();
+    auto main = module_.lookupSymbol<mlir::FuncOp>("main");
+    // This is an invariant check as Run returns failure if there is no main
+    // function and so the main proto shouldn't be consumed in that case.
+    CHECK(main) << "requires module to have main function";  // Crash Ok.
+    return lowered_computation_[main].proto();
   }
 
   // Lower function call to HLO call instruction
@@ -174,6 +174,13 @@ func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
   return %0: tensor<4xi32>
 }
 
+// CHECK-LABEL: func @bitwise_xor
+func @bitwise_xor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
+  // CHECK-NEXT: mhlo.xor
+  %0 = "tf.BitwiseXor"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  return %0: tensor<4xi32>
+}
+
 // CHECK-LABEL: func @bitwise_and
 func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
   // CHECK-NEXT: mhlo.and
@@ -835,7 +835,14 @@ func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?
 }
 
 // CHECK-LABEL: func @floordiv_unranked
-func @floordiv_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
+func @floordiv_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK-NOT: tf.FloorDiv
+  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %0: tensor<*xf32>
+}
+
+// CHECK-LABEL: func @floordiv_int
+func @floordiv_int(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
   // CHECK: tf.FloorDiv
   %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
   return %0: tensor<*xi32>
@@ -894,7 +901,7 @@ func @floormod_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?
 
 // CHECK-LABEL: func @floormod_unranked
 func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
-  // CHECK: tf.FloorMod
+  // CHECK-NOT: tf.FloorMod
   %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
   return %0: tensor<*xi32>
 }
@@ -0,0 +1,7 @@
+// RUN: not tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s 2>&1 | FileCheck %s
+
+// CHECK: conversion requires module with `main`
+func @non_main() {
+  %0 = "mhlo.constant"() {value = opaque<"mhlo", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32>
+  return
+}
@@ -114,7 +114,7 @@ def : Pat<(TF_ComplexOp $r, $i), (HLO_ComplexOp $r, $i)>;
 // Performs a substitution of FloorDiv, pseudo code below:
 //
 //  return floor(div(x, y))
-def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
+def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r),
           (HLO_FloorOp
            (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))),
           [(IEEEFloatTensor $l)]>;
@@ -166,7 +166,7 @@ def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
 //   return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y
 // Requires static shaped inputs to create constant splats and computation of
 // broadcast attributes.
-def : Pat<(TF_FloorModOp AnyRankedTensor:$l, AnyRankedTensor:$r),
+def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r),
           (HLO_SelectOp
            (HLOClient_BroadcastAndOp
             (HLOClient_BroadcastCompareOp
@@ -193,14 +193,15 @@ def : Pat<(TF_FloorModOp AnyRankedTensor:$l, AnyRankedTensor:$r),
 //===----------------------------------------------------------------------===//
 
 class DirectLogicalBinaryPat<Op FromOp, Op ToOp>
-  : Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r),
+  : Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
         (ToOp $l, $r, (BinBroadcastDimensions $l, $r)),
         [(SignedIntTensor $l)]>;
 
 foreach fromToBinPair = [[TF_LogicalAndOp, HLOClient_BroadcastAndOp],
                          [TF_LogicalOrOp, HLOClient_BroadcastOrOp],
+                         [TF_BitwiseAndOp, HLOClient_BroadcastAndOp],
                          [TF_BitwiseOrOp, HLOClient_BroadcastOrOp],
-                         [TF_BitwiseAndOp, HLOClient_BroadcastAndOp]] in
+                         [TF_BitwiseXorOp, HLOClient_BroadcastXorOp]] in
   def : DirectLogicalBinaryPat<fromToBinPair[0], fromToBinPair[1]>;
 
 //===----------------------------------------------------------------------===//
@@ -154,9 +154,11 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
     TypeID::get<TF::IgammaOp>(),
     TypeID::get<TF::IgammacOp>(),
     TypeID::get<TF::IgammaGradAOp>(),
+    TypeID::get<TF::InplaceAddOp>(),
     TypeID::get<TF::InTopKV2Op>(),
     TypeID::get<TF::InvertOp>(),
    TypeID::get<TF::InvOp>(),
+    TypeID::get<TF::KthOrderStatisticOp>(),
    TypeID::get<TF::LRNOp>(),
    TypeID::get<TF::LRNGradOp>(),
    TypeID::get<TF::LeakyReluGradOp>(),
@@ -170,6 +172,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
    TypeID::get<TF::LogicalOrOp>(),
    TypeID::get<TF::LogOp>(),
    TypeID::get<TF::LowerBoundOp>(),
+    TypeID::get<TF::MakeUniqueOp>(),
    TypeID::get<TF::MatMulOp>(),
    TypeID::get<TF::MatrixDiagV3Op>(),
    TypeID::get<TF::MatrixInverseOp>(),
@@ -248,6 +251,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
    TypeID::get<TF::TensorScatterAddOp>(),
    TypeID::get<TF::TensorScatterSubOp>(),
    TypeID::get<TF::TPUEmbeddingActivationsOp>(),
+    TypeID::get<TF::TopKUniqueOp>(),
+    TypeID::get<TF::TopKWithUniqueOp>(),
    TypeID::get<TF::TransposeOp>(),
    TypeID::get<TF::TridiagonalSolveOp>(),
    TypeID::get<TF::TruncateDivOp>(),
@@ -121,7 +121,7 @@ Status ConvertModule(std::unique_ptr<HloModule> hlo_module, ModuleOp module,
   // Run all HLO passes to produce an optimized module.
   auto result_or = backend->compiler()->RunHloPassesAndBufferAssignement(
       std::move(hlo_module), backend->default_stream_executor(),
-      backend->memory_allocator(), optimize_xla_hlo);
+      optimize_xla_hlo, {backend->memory_allocator()});
   TF_RETURN_WITH_CONTEXT_IF_ERROR(result_or.status(),
                                   "running XLA pass pipeline");
   std::unique_ptr<HloModule> optimized_hlo_module =
@@ -115,6 +115,16 @@ class ExecutableBuildOptions {
     return *this;
   }
 
+  // Thread pool for parallel compilation.
+  tensorflow::thread::ThreadPool* compile_thread_pool() const {
+    return compile_thread_pool_;
+  }
+  ExecutableBuildOptions& set_run_backend_only(
+      tensorflow::thread::ThreadPool* compile_thread_pool) {
+    compile_thread_pool_ = compile_thread_pool;
+    return *this;
+  }
+
  private:
   int device_ordinal_ = -1;
   Shape result_layout_;
@@ -128,6 +138,7 @@ class ExecutableBuildOptions {
   absl::optional<DeviceAssignment> device_assignment_;
   bool alias_passthrough_params_ = false;
   bool run_backend_only_ = false;
+  tensorflow::thread::ThreadPool* compile_thread_pool_ = nullptr;
 };
 
 }  // namespace xla
@@ -3347,6 +3347,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
   //     constant False if dimension is static.
   //   - Reduce: Convert to reduce or.
   //   - Constant: Convert to constant False.
+  //   - Reshape, slice, transpose, pad:
+  //     Convert into predicate type with same opcode.
   //   - Other ops: Not supported.
   // Create the instruction for the new handle.
   TF_ASSIGN_OR_RETURN(HloOpcode opcode,
@@ -3449,6 +3451,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
     case HloOpcode::kBroadcast:
     case HloOpcode::kConcatenate:
     case HloOpcode::kReshape:
+    case HloOpcode::kPad:
       break;
     case HloOpcode::kGetDimensionSize: {
       int64 dimension = instr_proto->dimensions(0);
@@ -26,8 +26,8 @@ static const char kCpuPlatformName[] = "cpu";
 
 CpuDevice::CpuDevice(int id,
                      std::unique_ptr<LocalDeviceState> local_device_state)
-    : PjRtDevice(id, std::move(local_device_state),
+    : PjRtStreamExecutorDevice(id, std::move(local_device_state),
                  /*device_kind=*/kCpuPlatformName) {}
 
 StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
   TF_ASSIGN_OR_RETURN(se::Platform * platform,
@@ -40,7 +40,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
   TF_ASSIGN_OR_RETURN(LocalClient * client,
                       ClientLibrary::GetOrCreateLocalClient(options));
 
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   for (int i = 0; i < client->device_count(); ++i) {
     se::StreamExecutorConfig config;
     config.ordinal = i;
@@ -57,11 +57,11 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
     devices.push_back(std::move(device));
   }
 
-  return std::make_unique<PjRtClient>(
+  return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
       kCpuName, client, std::move(devices), /*host_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
       /*should_stage_host_to_device_transfers=*/false,
-      /*gpu_run_options=*/nullptr);
+      /*gpu_run_options=*/nullptr));
 }
 
 }  // namespace xla
@@ -23,7 +23,7 @@ limitations under the License.
 
 namespace xla {
 
-class CpuDevice : public PjRtDevice {
+class CpuDevice : public PjRtStreamExecutorDevice {
  public:
   CpuDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state);
 };
@@ -35,9 +35,9 @@ namespace xla {
 namespace {
 
 // A custom PjRtClient that overrides the device assignment method.
-class GpuClient : public xla::PjRtClient {
+class GpuClient : public xla::PjRtStreamExecutorClient {
  public:
-  using xla::PjRtClient::PjRtClient;
+  using xla::PjRtStreamExecutorClient::PjRtStreamExecutorClient;
 
   xla::StatusOr<xla::DeviceAssignment> GetDefaultDeviceAssignment(
       int num_replicas, int num_partitions) const override;
@@ -55,7 +55,8 @@ xla::StatusOr<xla::DeviceAssignment> GpuClient::GetDefaultDeviceAssignment(
     return assignment;
   }
   // Fallback to default global device assignment if we can't run locally.
-  return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions);
+  return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
+                                                              num_partitions);
 }
 
 // Builds an xla::LocalClient for the GPU platform.
@@ -225,9 +226,9 @@ StatusOr<std::string> NcclIdStore::GetNcclUniqueId(
   return result.first->second;
 }
 
-std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
+std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> BuildLocalDevices(
     std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   for (auto& local_device : local_device_states) {
     int device_ordinal = local_device->device_ordinal();
     const se::DeviceDescription& description =
@@ -243,7 +244,7 @@ std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
 Status BuildDistributedDevices(
     std::vector<std::unique_ptr<LocalDeviceState>> local_device_states,
     std::shared_ptr<DistributedRuntimeClient> distributed_client, int node_id,
-    std::vector<std::unique_ptr<PjRtDevice>>* devices,
+    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>* devices,
     gpu::GpuExecutableRunOptions* gpu_executable_run_options) {
   LocalTopologyProto local_topology;
   local_topology.set_node_id(node_id);
@@ -306,8 +307,8 @@ Status BuildDistributedDevices(
 GpuDevice::GpuDevice(int id,
                      std::unique_ptr<LocalDeviceState> local_device_state,
                      std::string device_kind, int node_id)
-    : PjRtDevice(id, std::move(local_device_state), std::move(device_kind),
-                 node_id) {}
+    : PjRtStreamExecutorDevice(id, std::move(local_device_state),
+                               std::move(device_kind), node_id) {}
 
 StatusOr<std::unique_ptr<PjRtClient>> GetGpuClient(
     bool asynchronous, const GpuAllocatorConfig& allocator_config,
@@ -322,7 +323,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetGpuClient(
   auto host_memory_allocator =
       GetGpuHostAllocator(local_device_states.front()->executor());
 
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   auto gpu_run_options = absl::make_unique<gpu::GpuExecutableRunOptions>();
   if (distributed_client) {
     TF_RETURN_IF_ERROR(BuildDistributedDevices(
@@ -25,7 +25,7 @@ limitations under the License.
 
 namespace xla {
 
-class GpuDevice : public PjRtDevice {
+class GpuDevice : public PjRtStreamExecutorDevice {
  public:
   GpuDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state,
             std::string device_kind, int node_id);
@@ -26,8 +26,8 @@ static const char kInterpreterPlatformName[] = "interpreter";
 
 InterpreterDevice::InterpreterDevice(
     int id, std::unique_ptr<LocalDeviceState> local_device_state)
-    : PjRtDevice(id, std::move(local_device_state),
+    : PjRtStreamExecutorDevice(id, std::move(local_device_state),
                  /*device_kind=*/kInterpreterPlatformName) {}
 
 StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
   TF_ASSIGN_OR_RETURN(se::Platform * platform,
@@ -41,7 +41,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
   TF_ASSIGN_OR_RETURN(LocalClient * client,
                       ClientLibrary::GetOrCreateLocalClient(options));
 
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   se::StreamExecutor* executor =
       client->backend().stream_executor(0).ValueOrDie();
   auto device_state = absl::make_unique<LocalDeviceState>(
@@ -51,11 +51,11 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
       absl::make_unique<InterpreterDevice>(0, std::move(device_state));
   devices.push_back(std::move(device));
 
-  return std::make_unique<PjRtClient>(
+  return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
       "interpreter", client, std::move(devices), /*host_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
       /*should_stage_host_to_device_transfers=*/false,
-      /*gpu_run_options=*/nullptr);
+      /*gpu_run_options=*/nullptr));
 }
 
 }  // namespace xla
@@ -23,7 +23,7 @@ limitations under the License.
 
 namespace xla {
 
-class InterpreterDevice : public PjRtDevice {
+class InterpreterDevice : public PjRtStreamExecutorDevice {
  public:
   InterpreterDevice(int id,
                     std::unique_ptr<LocalDeviceState> local_device_state);
@@ -114,21 +114,22 @@ limitations under the License.
 
 namespace xla {
 
-PjRtPlatformId PjRtDevice::platform_id() const {
+PjRtPlatformId PjRtStreamExecutorDevice::platform_id() const {
   return client_->platform_id();
 }
-const std::string& PjRtDevice::platform_name() const {
+const std::string& PjRtStreamExecutorDevice::platform_name() const {
   return client_->platform_name();
 }
 
-StatusOr<LocalDeviceState*> PjRtDevice::GetLocalDeviceState() const {
+StatusOr<LocalDeviceState*> PjRtStreamExecutorDevice::GetLocalDeviceState()
+    const {
   if (local_device_state_) {
     return local_device_state_.get();
   }
   return InvalidArgument("Device %s is not a local device.", DebugString());
 }
 
-std::string PjRtDevice::DebugString() const {
+std::string PjRtStreamExecutorDevice::DebugString() const {
   return absl::StrCat(platform_name(), ":", id());
 }
 
@@ -153,14 +154,15 @@ StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
           devices[replica].size(), replica, devices[0].size());
     }
     for (int partition = 0; partition < devices[replica].size(); ++partition) {
-      if (devices[0][0]->platform_id() !=
-          devices[replica][partition]->platform_id()) {
+      if (devices[0][0]->client()->platform_id() !=
+          devices[replica][partition]->client()->platform_id()) {
        return InvalidArgument(
            "Device assignment passed to Compile() must have devices of a "
            "single kind, got %s for replica 0 partition 0 and %s for replica "
            "%d partition %d.",
-            devices[0][0]->platform_name(),
-            devices[replica][partition]->platform_name(), replica, partition);
+            devices[0][0]->client()->platform_name(),
+            devices[replica][partition]->client()->platform_name(), replica,
+            partition);
      }
      xla_assignment(replica, partition) = devices[replica][partition]->id();
    }
@@ -182,9 +184,9 @@ class CpuAllocator : public tensorflow::Allocator {
   }
 };
 
-PjRtClient::PjRtClient(
+PjRtStreamExecutorClient::PjRtStreamExecutorClient(
     std::string platform_name, LocalClient* client,
-    std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
+    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id,
     std::unique_ptr<se::DeviceMemoryAllocator> allocator,
    std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
    bool should_stage_host_to_device_transfers,
@@ -193,7 +195,7 @@ PjRtClient::PjRtClient(
      platform_name_(std::move(platform_name)),
      client_(client),
      host_memory_allocator_(std::move(host_memory_allocator)),
-      devices_(std::move(devices)),
+      owned_devices_(std::move(devices)),
      host_id_(host_id),
      owned_allocator_(std::move(allocator)),
      should_stage_host_to_device_transfers_(
@@ -211,12 +213,14 @@ PjRtClient::PjRtClient(
    host_memory_allocator_ = std::make_unique<CpuAllocator>();
  }
 
-  for (const std::unique_ptr<PjRtDevice>& device : devices_) {
+  for (const std::unique_ptr<PjRtStreamExecutorDevice>& device :
+       owned_devices_) {
+    devices_.push_back(device.get());
    CHECK(id_to_device_.insert({device->id(), device.get()}).second)
        << "Duplicate device id: " << device->id();
 
-    if (device->IsLocalDevice()) {
-      int idx = device->local_device_id();
+    if (device->IsAddressable()) {
+      int idx = device->local_hardware_id();
      if (idx >= local_devices_.size()) {
        local_devices_.resize(idx + 1);
      }
@@ -230,13 +234,14 @@ PjRtClient::PjRtClient(
  }
 }
 
-StatusOr<DeviceAssignment> PjRtClient::GetDefaultDeviceAssignment(
+StatusOr<DeviceAssignment> PjRtStreamExecutorClient::GetDefaultDeviceAssignment(
    int num_replicas, int num_partitions) const {
  return client_->backend().computation_placer()->AssignDevices(num_replicas,
                                                                 num_partitions);
 }
 
-std::unique_ptr<HloCostAnalysis> PjRtClient::GetHloCostAnalysis() {
+std::unique_ptr<HloCostAnalysis>
+PjRtStreamExecutorClient::GetHloCostAnalysis() {
  return absl::make_unique<HloCostAnalysis>(
      client_->backend().compiler()->ShapeSizeBytesFunction());
 }
@ -346,12 +351,13 @@ StatusOr<std::unique_ptr<PjRtBuffer>> AllocateDestinationBuffer(
|
|||||||
return InvalidArgument("Can't make a buffer from an empty tuple");
|
return InvalidArgument("Can't make a buffer from an empty tuple");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client);
|
||||||
TransferManager* transfer_manager =
|
TransferManager* transfer_manager =
|
||||||
client->client()->backend().transfer_manager();
|
se_client->client()->backend().transfer_manager();
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer,
|
||||||
ScopedShapedBuffer dst_buffer,
|
transfer_manager->AllocateScopedShapedBuffer(
|
||||||
transfer_manager->AllocateScopedShapedBuffer(
|
on_host_shape, se_client->allocator(),
|
||||||
on_host_shape, client->allocator(), local_device->device_ordinal()));
|
local_device->device_ordinal()));
|
||||||
if (local_device->allocation_model() ==
|
if (local_device->allocation_model() ==
|
||||||
LocalDeviceState::kComputeSynchronized) {
|
LocalDeviceState::kComputeSynchronized) {
|
||||||
if (copy_stream == nullptr) {
|
if (copy_stream == nullptr) {
|
||||||
@@ -543,18 +549,21 @@ void PjRtBuffer::ScopedHold::AddToInput(

 bool PjRtBuffer::IsOnCpu() const { return client()->platform_id() == kCpuId; }

-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
+StatusOr<std::unique_ptr<PjRtBuffer>>
+PjRtStreamExecutorClient::BufferFromHostBuffer(
     const void* data, const Shape& shape,
     HostBufferSemantics host_buffer_semantics,
     std::shared_ptr<void> buffer_reference, PjRtDevice* device) {
-  tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostBuffer");
-  VLOG(2) << "PjRtClient::BufferFromHostBuffer: shape: " << shape.ToString()
-          << " device: " << device->DebugString();
+  tensorflow::profiler::TraceMe traceme(
+      "PjRtStreamExecutorClient::BufferFromHostBuffer");
+  VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: "
+          << shape.ToString() << " device: " << device->DebugString();
   if (shape.IsTuple()) {
     return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple");
   }
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
-                      device->GetLocalDeviceState());
+                      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
+                          ->GetLocalDeviceState());
   int64 size = ShapeUtil::ByteSizeOf(shape);

   TransferManager* transfer_manager = client()->backend().transfer_manager();
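As a reference point for the renamed entry point, here is a hedged caller-side sketch of `BufferFromHostBuffer` using only the signature and the `HostBufferSemantics::kZeroCopy` value that appear in this diff; `client`, `device`, and the helper name `UploadVector` are illustrative:

#include <vector>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> UploadVector(
    xla::PjRtClient* client, xla::PjRtDevice* device,
    const std::vector<float>& host_data) {
  xla::Shape shape = xla::ShapeUtil::MakeShape(
      xla::F32, {static_cast<xla::int64>(host_data.size())});
  // With zero-copy semantics the runtime may alias `host_data`, so the caller
  // must keep it alive for as long as the returned buffer exists.
  return client->BufferFromHostBuffer(
      host_data.data(), shape,
      xla::PjRtClient::HostBufferSemantics::kZeroCopy,
      /*buffer_reference=*/nullptr, device);
}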
@@ -708,20 +717,23 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
   return py_buffer;
 }

-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
-    const Shape& shape, PjRtDevice* device) {
+StatusOr<std::unique_ptr<PjRtBuffer>>
+PjRtStreamExecutorClient::CreateUninitializedBuffer(const Shape& shape,
+                                                    PjRtDevice* device) {
   return CreateUninitializedBuffer(shape, device, nullptr);
 }

-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
+StatusOr<std::unique_ptr<PjRtBuffer>>
+PjRtStreamExecutorClient::CreateUninitializedBuffer(
     const Shape& shape, PjRtDevice* device,
     std::shared_ptr<BufferSequencingEvent> definition_event) {
   tensorflow::profiler::TraceMe traceme(
-      "PjRtClient::CreateUninitializedBuffer");
-  VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: "
+      "PjRtStreamExecutorClient::CreateUninitializedBuffer");
+  VLOG(2) << "PjRtStreamExecutorClient::CreateUninitializedBuffer: shape: "
           << shape.ToString() << " device: " << device->DebugString();
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
-                      device->GetLocalDeviceState());
+                      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
+                          ->GetLocalDeviceState());

   TransferManager* transfer_manager = client()->backend().transfer_manager();
   TF_ASSIGN_OR_RETURN(Shape compact_shape,
@@ -733,13 +745,16 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
                       definition_event);
 }

-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
-    const LiteralSlice& literal, PjRtDevice* device) {
-  tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostLiteral");
-  VLOG(2) << "PjRtClient::BufferFromHostLiteral: shape: "
+StatusOr<std::unique_ptr<PjRtBuffer>>
+PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
+                                                PjRtDevice* device) {
+  tensorflow::profiler::TraceMe traceme(
+      "PjRtStreamExecutorClient::BufferFromHostLiteral");
+  VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostLiteral: shape: "
           << literal.shape().ToString() << " device: " << device->DebugString();
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
-                      device->GetLocalDeviceState());
+                      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
+                          ->GetLocalDeviceState());

   TransferManager* transfer_manager = client()->backend().transfer_manager();
   TF_ASSIGN_OR_RETURN(
@@ -792,7 +807,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
   return py_buffer;
 }

-void PjRtClient::MakeCrossHostReceiveBuffers(
+void PjRtStreamExecutorClient::MakeCrossHostReceiveBuffers(
     absl::Span<const Shape> shapes, PjRtDevice* device,
     PjRtCrossHostRecvNotifier&& notifier) {
   if (shapes.empty()) {
@@ -801,7 +816,9 @@ void PjRtClient::MakeCrossHostReceiveBuffers(
     return;
   }

-  auto local_device_or = device->GetLocalDeviceState();
+  auto local_device_or =
+      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
+          ->GetLocalDeviceState();
   if (!local_device_or.ok()) {
     notifier(local_device_or.status());
     return;
@@ -828,27 +845,30 @@ void PjRtClient::MakeCrossHostReceiveBuffers(
 }

 // Transfer the given literal to the infeed queue of the given local device.
-Status PjRtDevice::TransferToInfeed(const LiteralSlice& literal) const {
+Status PjRtStreamExecutorDevice::TransferToInfeed(
+    const LiteralSlice& literal) const {
   // Only support infeed to local device.
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
   return local_device->client()->TransferToInfeedLocal(
       literal, local_device->device_ordinal());
 }

-StatusOr<Literal> PjRtDevice::TransferFromOutfeed(const Shape& shape) const {
+StatusOr<Literal> PjRtStreamExecutorDevice::TransferFromOutfeed(
+    const Shape& shape) const {
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
   return local_device->client()->TransferFromOutfeedLocal(
       shape, local_device->device_ordinal());
 }

-StatusOr<PjRtDevice*> PjRtClient::LookupLocalDevice(int local_device_id) const {
+StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
+    int local_hardware_id) const {
   for (auto* device : local_devices_) {
-    if (local_device_id == device->local_device_id()) {
+    if (local_hardware_id == device->local_hardware_id()) {
       return device;
     }
   }
-  return InvalidArgument("No matching device found for local_device_id %d",
-                         local_device_id);
+  return InvalidArgument("No matching device found for local_hardware_id %d",
+                         local_hardware_id);
 }

 PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
@@ -873,7 +893,8 @@ PjRtBuffer::~PjRtBuffer() {
 }

 int64 PjRtBuffer::OnDeviceSizeInBytes() const {
-  return client_->client()
+  return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+      ->client()
       ->backend()
       .transfer_manager()
       ->GetByteSizeRequirement(on_device_shape_);
@@ -919,7 +940,9 @@ StatusOr<std::shared_ptr<TrackedDeviceBuffer>> PjRtBuffer::Release(
     // the final set of usage events.
     events = device_buffer->LockUseAndTransferUsageEvents();
   }
-  LocalDeviceState* local_device_state = device_->local_device_state();
+  LocalDeviceState* local_device_state =
+      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
+          ->local_device_state();
   if (wait_for_operations_to_complete) {
     // Block the host until all usage events have completed. Usage events
     // dominate definition events, so this also waits for the buffer to be
@@ -1080,7 +1103,9 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
   }
   ScopedHold device_buffer(this, ScopedHold::kUsage);
   std::shared_ptr<HostValue> host_value;
-  LocalDeviceState* local_device = device_->local_device_state();
+  LocalDeviceState* local_device =
+      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
+          ->local_device_state();
   se::Stream* stream = local_device->GetDeviceToHostStream();
   const xla::Layout& host_layout =
       layout.has_value() ? layout.value() : on_host_shape_.layout();
@@ -1122,12 +1147,16 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
   host_value->value = std::make_shared<Literal>(host_shape);
   ShapedBuffer shaped_buffer =
       device_buffer->AsShapedBuffer(host_shape, on_device_shape_);
-  client_->client()->backend().transfer_manager()->TransferLiteralFromDevice(
-      stream, shaped_buffer, host_value->value.get(),
-      [host_value](Status done_status) {
-        host_value->status = done_status;
-        host_value->ready.Notify();
-      });
+  tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+      ->client()
+      ->backend()
+      .transfer_manager()
+      ->TransferLiteralFromDevice(stream, shaped_buffer,
+                                  host_value->value.get(),
+                                  [host_value](Status done_status) {
+                                    host_value->status = done_status;
+                                    host_value->ready.Notify();
+                                  });

   auto usage_event = std::make_shared<BufferSequencingEvent>();
   StatusOr<EventPool::Handle> event_or =
@@ -1156,7 +1185,7 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,

 StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
     const bool discard_cached_copy, absl::optional<xla::Layout> layout) {
-  tensorflow::profiler::TraceMe traceme("PjRtClient::ToLiteral");
+  tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::ToLiteral");
   TF_ASSIGN_OR_RETURN(std::shared_ptr<HostValue> host_value,
                       CopyToHostAsyncInternal(discard_cached_copy, layout));
   if (host_value == nullptr) {
@@ -1241,8 +1270,9 @@ PjRtBuffer::CopyToDeviceHelper(
       // StallStreamOnError only makes sure the destination device is ok, so
       // make sure that the src buffer remains valid until after any transfers
       // have completed.
-      device_->local_device_state()->ThenRelease(transfer_stream,
-                                                 src_device_buffer);
+      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
+          ->local_device_state()
+          ->ThenRelease(transfer_stream, src_device_buffer);
     }
     return copy_event_or.status();
   }
@@ -1265,14 +1295,20 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
     TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral());
     return dst_device->client()->BufferFromHostBuffer(
         literal->untyped_data(), literal->shape(),
-        PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, dst_device);
+        PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy, nullptr,
+        dst_device);
   }

-  TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device,
-                      dst_device->GetLocalDeviceState());
+  TF_ASSIGN_OR_RETURN(
+      LocalDeviceState * dst_local_device,
+      tensorflow::down_cast<PjRtStreamExecutorDevice*>(dst_device)
+          ->GetLocalDeviceState());
   LocalDeviceState* transfer_local_device =
-      client_->EnqueueD2DTransfersOnSrcStream() ? device_->local_device_state()
-                                                : dst_local_device;
+      tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+              ->EnqueueD2DTransfersOnSrcStream()
+          ? tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
+                ->local_device_state()
+          : dst_local_device;
   CHECK_EQ(dst_local_device->allocation_model(),
            transfer_local_device->allocation_model());

@@ -1310,7 +1346,9 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
   // alternative is to ensure, before freeing the buffer, that the compute
   // stream is synchronized past the transfer, but it seems better to hold onto
   // the buffer too long than to stall the compute stream.
-  RecordUsage(std::move(src_device_buffer), device_->local_device_state(),
+  RecordUsage(std::move(src_device_buffer),
+              tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
+                  ->local_device_state(),
               transfer_local_device, event, transfer_stream,
               /*prefer_to_retain_reference=*/true);

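For orientation, a hedged sketch of the caller-visible contract of `CopyToDevice` and `BlockHostUntilReady` as they appear in these hunks; `MirrorToDevice` and its arguments are illustrative, and the usual `TF_ASSIGN_OR_RETURN` status macro is assumed to be available:

xla::Status MirrorToDevice(xla::PjRtBuffer* buffer,
                           xla::PjRtDevice* dst_device) {
  // CopyToDevice takes the same-client D2D path or the literal round trip
  // shown above, depending on which client owns `dst_device`.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtBuffer> copy,
                      buffer->CopyToDevice(dst_device));
  // Block until the copy's definition events have completed on-device.
  return copy->BlockHostUntilReady();
}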
@@ -1318,7 +1356,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
 }

 Status PjRtBuffer::CopyToRemoteDevice(absl::string_view serialized_descriptor) {
-  return client_->CopyToRemoteDevice(this, serialized_descriptor);
+  return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+      ->CopyToRemoteDevice(this, serialized_descriptor);
 }

 Status PjRtBuffer::BlockHostUntilReady() {
@@ -1332,7 +1371,9 @@ Status PjRtBuffer::BlockHostUntilReady() {
     }
     device_buffer = device_buffer_;
   }
-  LocalDeviceState* local_device_state = device_->local_device_state();
+  LocalDeviceState* local_device_state =
+      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
+          ->local_device_state();
   std::unique_ptr<se::Stream> stream;
   for (auto& event : device_buffer->definition_events()) {
     if (!event->IsComplete()) {
@@ -1378,9 +1419,13 @@ StatusOr<TupleHandle> MakeTupleHelper(
   Shape on_host_shape = ShapeUtil::MakeTupleShape(host_shapes);
   Shape on_device_shape = ShapeUtil::MakeTupleShape(device_shapes);

-  se::DeviceMemoryAllocator* allocator = client->allocator();
+  se::DeviceMemoryAllocator* allocator =
+      tensorflow::down_cast<PjRtStreamExecutorClient*>(client)->allocator();
   TransferManager* transfer_manager =
-      client->client()->backend().transfer_manager();
+      tensorflow::down_cast<PjRtStreamExecutorClient*>(client)
+          ->client()
+          ->backend()
+          .transfer_manager();
   se::Stream* stream = local_device->host_to_device_stream();
   TF_ASSIGN_OR_RETURN(
       se::OwningDeviceMemory root_table_memory,
@@ -1444,14 +1489,6 @@ std::unique_ptr<PjRtBuffer> OutputBufferHelper(
               /*prefer_to_retain_reference=*/false);
   return pjrt_buffer;
 }

-static PjRtDevice* LookupDevice(const PjRtClient& client, int device_id) {
-  auto it = client.id_to_device().find(device_id);
-  CHECK(it != client.id_to_device().end())
-      << "Unknown device id: " << device_id;
-  return it->second;
-}
-
 }  // namespace

 PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
@@ -1459,7 +1496,8 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
     bool parameter_is_tupled_arguments,
     std::shared_ptr<DeviceAssignment> device_assignment,
     std::vector<LogicalDeviceIds> addressable_device_logical_ids,
-    std::vector<PjRtDevice*> addressable_devices, PjRtClient* client)
+    std::vector<PjRtDevice*> addressable_devices,
+    PjRtStreamExecutorClient* client)
     : client_(client),
       device_assignment_(std::move(device_assignment)),
       parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
@@ -1482,7 +1520,7 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
     VLOG(1) << "PjRtStreamExecutorExecutable device_assignment:\n"
             << device_assignment_->ToString();
     CHECK_GE(addressable_devices_.size(), 1) << device_assignment_->ToString();
-    CHECK_LE(addressable_devices_.size(), client_->local_device_count())
+    CHECK_LE(addressable_devices_.size(), client_->addressable_device_count())
         << "Inconsistent local device count.";
     num_partitions = device_assignment_->computation_count();
   }
@@ -1584,7 +1622,7 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
     absl::Span<const PjRtBuffer::ScopedHold> device_buffers,
     absl::flat_hash_set<BufferSequencingEvent*>& events) const {
   std::vector<ExecutionInput> execution_inputs;
-  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
+  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   // Lift tuple_handle outside the conditional so that the event it returns is
   // not destroyed until after the loop below that waits on events.
   absl::optional<TupleHandle> tuple_handle;
@@ -1607,8 +1645,10 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
           execution_input.MutableBuffers()->begin();
       ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
           execution_input.MutableBuffers()->end();
-      device_buffers[i].AddToInput(&input_iterator, iterator_end,
-                                   &execution_input, client_->allocator());
+      device_buffers[i].AddToInput(
+          &input_iterator, iterator_end, &execution_input,
+          tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+              ->allocator());
       CHECK(input_iterator == iterator_end);
     }
   }
@@ -1628,8 +1668,10 @@ StatusOr<ScopedShapedBuffer> PjRtStreamExecutorExecutable::EnqueueExecution(
     int executable_idx, const RunId& run_id, const ExecuteOptions& options,
     PjRtDevice* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers,
     std::shared_ptr<DeviceAssignment> device_assignment) const {
-  int device_ordinal = device->local_device_state()->device_ordinal();
-  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
+  int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
+                           ->local_device_state()
+                           ->device_ordinal();
+  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   tensorflow::profiler::TraceMeConsumer activity(
       "LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
       run_id.ToInt());
@@ -1765,7 +1807,7 @@ PjRtStreamExecutorExecutable::MakeOutputBuffers(
     std::shared_ptr<BufferSequencingEvent> definition_event,
     PjRtDevice* device) const {
   std::vector<std::unique_ptr<PjRtBuffer>> outputs;
-  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
+  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) {
     int tuple_count = result_buffer.on_host_shape().tuple_shapes_size();
     outputs.reserve(tuple_count);
@@ -1802,7 +1844,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
   if (device == nullptr) {
     CHECK(device_assignment_ != nullptr);
    const int device_id = (*device_assignment_)(replica, partition);
-    device = LookupDevice(*client_, device_id);
+    TF_ASSIGN_OR_RETURN(device, client_->LookupDevice(device_id));
     device_assignment = device_assignment_;
   } else {
     CHECK(device_assignment_ == nullptr);
@@ -1814,7 +1856,9 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
   }

   CHECK_EQ(device->host_id(), client_->host_id());
-  int device_ordinal = device->local_device_state()->device_ordinal();
+  int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
+                           ->local_device_state()
+                           ->device_ordinal();
   tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute");
   VLOG(3) << "Replica " << replica << ", partition " << partition
           << " mapped to device ordinal for execution: " << device_ordinal;
@@ -1836,7 +1880,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
   ScopedShapedBuffer result_buffer =
       result_buffer_or_status.ConsumeValueOrDie();

-  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
+  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   se::Stream* stream = device_state->compute_stream();
   StatusOr<EventPool::Handle> event_or =
       device_state->event_pool().ThenAllocateAndRecordEvent(stream);
@@ -1922,7 +1966,9 @@ PjRtStreamExecutorExecutable::Execute(
     const int replica = addressable_device_logical_ids_[i].replica;
     const int partition = addressable_device_logical_ids_[i].partition;
     PjRtDevice* device = addressable_devices_[i];
-    const LocalDeviceState& device_state = *device->local_device_state();
+    const LocalDeviceState& device_state =
+        *tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
+             ->local_device_state();
     device_state.execute_thread()->Schedule([&, replica, partition, i] {
       results[i] = ExecuteHelper(argument_handles[i], replica, partition,
                                  run_id, options);
@@ -2131,9 +2177,9 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(

 }  // namespace

-StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
+StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
     const XlaComputation& computation, CompileOptions options) {
-  tensorflow::profiler::TraceMe traceme("PjRtClient::Compile");
+  tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");

   ExecutableBuildOptions& build_options = options.executable_build_options;
   if (!build_options.device_allocator()) {
@@ -2153,14 +2199,15 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
     num_partitions = 1;
   } else {
     if (!build_options.has_device_assignment()) {
-      VLOG(2) << "PjRtClient::Compile using default device_assignment.";
+      VLOG(2) << "PjRtStreamExecutorClient::Compile using default "
+                 "device_assignment.";
       TF_ASSIGN_OR_RETURN(
          DeviceAssignment device_assignment,
          GetDefaultDeviceAssignment(build_options.num_replicas(),
                                     build_options.num_partitions()));
      build_options.set_device_assignment(device_assignment);
    }
-    VLOG(2) << "PjRtClient::Compile device_assignment:\n"
+    VLOG(2) << "PjRtStreamExecutorClient::Compile device_assignment:\n"
            << build_options.device_assignment().ToString();
    num_replicas = build_options.device_assignment().replica_count();
    num_partitions = build_options.device_assignment().computation_count();
@@ -2234,7 +2281,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
     for (int replica = 0; replica < num_replicas; ++replica) {
       for (int partition = 0; partition < num_partitions; ++partition) {
         int device_id = (*device_assignment)(replica, partition);
-        PjRtDevice* device = LookupDevice(*this, device_id);
+        TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
         if (device->host_id() != host_id()) {
           VLOG(3) << "Non-local device: " << device_id;
           continue;
@@ -2254,7 +2301,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(

     if (build_options.device_ordinal() < 0) {
       build_options.set_device_ordinal(
-          addressable_devices.front()->local_device_state()->device_ordinal());
+          addressable_devices.front()->local_hardware_id());
     }
   }

@@ -45,6 +45,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/casts.h"
 #include "tensorflow/core/platform/fingerprint.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/platform/types.h"
@@ -67,16 +68,56 @@ class PjRtClient;

 class PjRtDevice {
  public:
-  explicit PjRtDevice(int id,
-                      std::unique_ptr<LocalDeviceState> local_device_state,
-                      std::string device_kind, int host_id = 0)
+  virtual ~PjRtDevice() {}
+
+  // Return the client that owns this device.
+  virtual PjRtClient* client() const = 0;
+
+  // Whether client can issue command to this device.
+  virtual bool IsAddressable() const = 0;
+
+  // The ID of this device. IDs are unique among devices of this type
+  // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
+  // hosts' devices. This is the ID that should be used in a DeviceAssignment.
+  virtual int id() const = 0;
+
+  // The task ID of this device according to TpuTopology. This is not the same
+  // as PjRtClient::host_id() in a multi-task setting, where each client can see
+  // devices from all tasks, but only a subset of them are addressable and have
+  // the same task_id as the client.
+  virtual int host_id() const = 0;
+
+  // Opaque hardware ID, e.g., the CUDA device number, useful for identifying
+  // which GPU when interacting with non-JAX code. In general, not guaranteed to
+  // be dense, and -1 if undefined.
+  virtual int local_hardware_id() const = 0;
+
+  // A vendor-dependent string that uniquely identifies the kind of device,
+  // e.g., "Tesla V100-SXM2-16GB". May be used to determine whether two GPUs are
+  // compatible compilation.
+  virtual const std::string& device_kind() const = 0;
+
+  virtual std::string DebugString() const = 0;
+
+  // Transfer the given literal to the infeed queue.
+  virtual Status TransferToInfeed(const LiteralSlice& literal) const = 0;
+
+  // Transfer and return a value of the given shape from the outfeed queue.
+  virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const = 0;
+};
+
+class PjRtStreamExecutorDevice : public PjRtDevice {
+ public:
+  explicit PjRtStreamExecutorDevice(
+      int id, std::unique_ptr<LocalDeviceState> local_device_state,
+      std::string device_kind, int host_id = 0)
       : id_(id),
-        local_device_id_(
+        device_ordinal_(
            local_device_state ? local_device_state->device_ordinal() : -1),
        local_device_state_(std::move(local_device_state)),
        host_id_(host_id),
        device_kind_(std::move(device_kind)) {}
-  virtual ~PjRtDevice() {}
+  ~PjRtStreamExecutorDevice() override {}

   // Must set client exactly once.
   void SetClient(PjRtClient* client) {
@@ -84,14 +125,25 @@ class PjRtDevice {
     client_ = client;
   }

+  // Task ID. This is always 0 on single-task setup.
+  int host_id() const override { return host_id_; }
+
+  // Return `platform_id` from client.
+  PjRtPlatformId platform_id() const;
+
+  // Return `platform_name` from client.
+  const std::string& platform_name() const;
+
+  PjRtClient* client() const override { return client_; }
+
   // The ID of this device. IDs are unique among devices of this type
   // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
   // hosts' devices. This is the ID that should be used in a DeviceAssignment.
-  int id() const { return id_; }
+  int id() const override { return id_; }

-  bool IsLocalDevice() const { return local_device_id_ != -1; }
+  bool IsAddressable() const override { return device_ordinal_ != -1; }

-  int local_device_id() const { return local_device_id_; }
+  int local_hardware_id() const override { return device_ordinal_; }

   // If this is a device local to this host, returns a LocalDeviceState object
   // that can be used to manipulate the device. Returns nullptr if the device is
@@ -105,32 +157,21 @@ class PjRtDevice {
   // is not local to this host.
   StatusOr<LocalDeviceState*> GetLocalDeviceState() const;

-  // The ID of this device's host. This is always 0 on single-host platforms.
-  int host_id() const { return host_id_; }
-
-  // Return `platform_id` from client.
-  PjRtPlatformId platform_id() const;
-
-  // Return `platform_name` from client.
-  const std::string& platform_name() const;
-
   // A vendor-dependent string that uniquely identifies the kind of device.
-  const std::string& device_kind() const { return device_kind_; }
+  const std::string& device_kind() const override { return device_kind_; }

-  virtual std::string DebugString() const;
-
-  PjRtClient* client() const { return client_; }
+  std::string DebugString() const override;

   // Transfer the given literal to the infeed queue of the given localdevice.
-  virtual Status TransferToInfeed(const LiteralSlice& literal) const;
+  Status TransferToInfeed(const LiteralSlice& literal) const override;

   // Transfer and return a value of the given shape from the outfeed of the
   // given device.
-  virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const;
+  StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override;

  private:
   const int id_;
-  const int local_device_id_;  // -1 means not local.
+  const int device_ordinal_;  // -1 means not local.
   const std::unique_ptr<LocalDeviceState> local_device_state_;
   const int host_id_;
   const std::string device_kind_;
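With `PjRtDevice` reduced to a pure interface above, a backend that is not built on StreamExecutor can subclass it directly. A minimal sketch under the assumption that the pure-virtual set is exactly the one introduced in this hunk; `FakeDevice`, its trivial method bodies, and the include list are illustrative only:

#include <string>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

namespace xla {

// A device that is visible in the topology but not addressable from this
// process, so infeed/outfeed are rejected and local_hardware_id() is -1.
class FakeDevice : public PjRtDevice {
 public:
  FakeDevice(int id, int host_id) : id_(id), host_id_(host_id) {}

  PjRtClient* client() const override { return nullptr; }
  bool IsAddressable() const override { return false; }
  int id() const override { return id_; }
  int host_id() const override { return host_id_; }
  int local_hardware_id() const override { return -1; }
  const std::string& device_kind() const override { return kind_; }
  std::string DebugString() const override {
    return absl::StrCat("FakeDevice(id=", id_, ")");
  }
  Status TransferToInfeed(const LiteralSlice& literal) const override {
    return Unimplemented("Infeed is not supported on FakeDevice.");
  }
  StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override {
    return Unimplemented("Outfeed is not supported on FakeDevice.");
  }

 private:
  const int id_;
  const int host_id_;
  const std::string kind_ = "fake";
};

}  // namespace xla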
@@ -178,86 +219,62 @@ class PjRtExecutable;
 // alive as long as any of the other runtime objects are alive.
 class PjRtClient {
  public:
-  // `allocator` may null, in which case the platform default allocator is used.
-  explicit PjRtClient(
-      std::string platform_name, LocalClient* client,
-      std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
-      std::unique_ptr<se::DeviceMemoryAllocator> allocator,
-      std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
-      bool should_stage_host_to_device_transfers,
-      std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
   virtual ~PjRtClient() = default;

+  // TODO(zhangqiaorjc): Rename to task_id.
+  // Return the task id of this client. In single-task setting, always 0.
+  virtual int host_id() const = 0;
+
+  // Return the number of devices in the entire computation. In multi-headed
+  // client setting, some are addressable by this client, some are not. In a
+  // single-client setting, this is equal to the number of addressable devices.
+  virtual int device_count() const = 0;
+
+  // Return number of addressable devices. Addressable devices are those that
+  // the client can issue commands to.
+  virtual int addressable_device_count() const = 0;
+
+  // Return all devices in the entire computation, including addressable and
+  // non-addressable devices.
+  virtual absl::Span<PjRtDevice* const> devices() const = 0;
+
+  // TODO(zhangqiaorjc): Rename to addressable_devices.
+  // Return only addressable devices.
+  virtual absl::Span<PjRtDevice* const> local_devices() const = 0;
+
+  // Lookup any PjRtDevice for a given PjRtDevice::id().
+  virtual StatusOr<PjRtDevice*> LookupDevice(int device_id) const = 0;
+
+  // Return an addressable PjRtDevice for a given
+  // PjRtDevice::local_hardware_id().
+  virtual StatusOr<PjRtDevice*> LookupAddressableDevice(
+      int local_hardware_id) const = 0;
+
+  // Return an ID that identifies the platform (CPU/GPU/TPU).
+  virtual PjRtPlatformId platform_id() const = 0;
+
+  // Returns a string that identifies the platform (CPU/GPU/TPU).
+  virtual const std::string& platform_name() const = 0;
+
+  // Return a device-specific default device assignment, e.g., GPU and TPU may
+  // be different.
   virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
-      int num_replicas, int num_partitions) const;
+      int num_replicas, int num_partitions) const = 0;

-  int device_count() const { return devices_.size(); }
-  int local_device_count() const { return local_devices_.size(); }
-  const std::vector<std::unique_ptr<PjRtDevice>>& devices() const {
-    return devices_;
-  }
-  const std::vector<PjRtDevice*>& local_devices() const {
-    return local_devices_;
-  }
-  const std::map<int, PjRtDevice*>& id_to_device() const {
-    return id_to_device_;
-  }
-  int host_id() const { return host_id_; }
-  PjRtPlatformId platform_id() const { return platform_id_; }
-  const std::string& platform_name() const { return platform_name_; }
-
-  LocalDeviceState& device_state(int device_ordinal) const {
-    return *local_devices_.at(device_ordinal)->local_device_state();
-  }
-
-  // Return a local PjRtDevice for a given `local_device_id`.
-  virtual StatusOr<PjRtDevice*> LookupLocalDevice(int local_device_id) const;
-
-  LocalClient* client() const { return client_; }
-  se::DeviceMemoryAllocator* allocator() const { return allocator_; }
-  tensorflow::Allocator* host_memory_allocator() const {
-    return host_memory_allocator_.get();
-  }
-  bool should_stage_host_to_device_transfers() const {
-    return should_stage_host_to_device_transfers_;
-  }
-
-  gpu::GpuExecutableRunOptions* gpu_run_options() const {
-    return gpu_run_options_.get();
-  }
-
-  tensorflow::thread::ThreadPool* h2d_transfer_pool() {
-    return &h2d_transfer_pool_;
-  }
-
-  // Most platforms expect device-to-device transfers to be enqueued on the
-  // source d2d stream, but some platforms use the destination d2d stream. This
-  // function specifies which one the platform expects.
-  virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
-
-  // Generates a unique fingerprint for `executable`.
-  virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
-      const PjRtExecutable& executable) const {
-    return absl::optional<std::string>();
-  }
-
   // Returns a backend-specific HLO cost analysis visitor.
-  virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis();
+  virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() = 0;

+  // Compile `computation` with given `options`.
   virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
-      const XlaComputation& computation, CompileOptions options);
+      const XlaComputation& computation, CompileOptions options) = 0;

+  // Generates a unique fingerprint for `executable`, may be absl::nullopt.
+  virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
+      const PjRtExecutable& executable) const = 0;
+
   // Creates a buffer on the device without initializing or copying any data.
-  // An optional `definition_event` may be speficied that can be used to
-  // ensure the buffer isn't referenced until some external mechanism has
-  // initialized the data.
-  // NOTE: The sequencing mechanism is not guaranteed to be supported by all
-  // future backends and so callers should avoid wherever possible.
   virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
-      const Shape& shape, PjRtDevice* device);
-  virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
-      const Shape& shape, PjRtDevice* device,
-      std::shared_ptr<BufferSequencingEvent> definition_event);
+      const Shape& shape, PjRtDevice* device) = 0;

   // Describes the semantics the caller to BufferFromHostBuffer expects from the
   // runtime, in a total order from most restrictive to least restrictive.
@@ -289,13 +306,13 @@ class PjRtClient {
   virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
       const void* data, const Shape& shape,
       HostBufferSemantics host_buffer_semantics,
-      std::shared_ptr<void> buffer_reference, PjRtDevice* device);
+      std::shared_ptr<void> buffer_reference, PjRtDevice* device) = 0;

   // Note that literal must remain in scope until the transfer has completed, so
   // the caller should, for example, wait for BlockHostUntilReady() completes on
   // the return value before letting literal go out of scope.
   virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
-      const LiteralSlice& literal, PjRtDevice* device);
+      const LiteralSlice& literal, PjRtDevice* device) = 0;

   // Asynchronously makes a vector of PjRtBuffers that can be used to receive
   // cross host transfers using `client` on `device'. `shapes` must be the exact
@@ -308,18 +325,140 @@ class PjRtClient {
   // buffers will become ready until *all* of the sends have completed.
   virtual void MakeCrossHostReceiveBuffers(
       absl::Span<const Shape> shapes, PjRtDevice* device,
-      PjRtCrossHostRecvNotifier&& notifier);
+      PjRtCrossHostRecvNotifier&& notifier) = 0;

-  virtual StatusOr<ChannelHandle> CreateChannelHandle() {
+  // Create ChannelHandles for XLA send/recv.
+  virtual StatusOr<ChannelHandle> CreateChannelHandle() = 0;
+  virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() = 0;
+  virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() = 0;
+};
+
+class PjRtStreamExecutorClient : public PjRtClient {
+ public:
+  // `allocator` may null, in which case the platform default allocator is used.
+  explicit PjRtStreamExecutorClient(
+      std::string platform_name, LocalClient* client,
+      std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
+      int host_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
+      std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
+      bool should_stage_host_to_device_transfers,
+      std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
+  ~PjRtStreamExecutorClient() override = default;
+
+  int host_id() const override { return host_id_; }
+
+  int device_count() const override { return devices_.size(); }
+  int addressable_device_count() const override {
+    return local_devices_.size();
+  }
+  absl::Span<PjRtDevice* const> devices() const override { return devices_; }
+  absl::Span<PjRtDevice* const> local_devices() const override {
+    return local_devices_;
+  }
+
+  StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
+    auto it = id_to_device_.find(device_id);
+    if (it != id_to_device_.end()) {
+      return it->second;
+    }
+    return InvalidArgument("No matching device found for device_id %d",
+                           device_id);
+  }
+
+  StatusOr<PjRtDevice*> LookupAddressableDevice(
+      int local_hardware_id) const override;
+
+  PjRtPlatformId platform_id() const override { return platform_id_; }
+  const std::string& platform_name() const override { return platform_name_; }
+
+  // Most platforms expect device-to-device transfers to be enqueued on the
+  // source d2d stream, but some platforms use the destination d2d stream. This
+  // function specifies which one the platform expects.
+  virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
+
+  StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
+      int num_replicas, int num_partitions) const override;
+
+  StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+      const XlaComputation& computation, CompileOptions options) override;
+
+  // Generates a unique fingerprint for `executable`.
+  StatusOr<absl::optional<std::string>> ExecutableFingerprint(
+      const PjRtExecutable& executable) const override {
+    return absl::optional<std::string>();
+  }
+
+  // Returns a backend-specific HLO cost analysis visitor.
+  std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() override;
+
+  // Creates a buffer on the device without initializing or copying any data.
+  // An optional `definition_event` may be speficied that can be used to
+  // ensure the buffer isn't referenced until some external mechanism has
+  // initialized the data.
+  // NOTE: The sequencing mechanism is not guaranteed to be supported by all
+  // future backends and so callers should avoid wherever possible.
+  StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
+      const Shape& shape, PjRtDevice* device) override;
+  virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
+      const Shape& shape, PjRtDevice* device,
+      std::shared_ptr<BufferSequencingEvent> definition_event);
+
+  StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
+      const void* data, const Shape& shape,
+      HostBufferSemantics host_buffer_semantics,
+      std::shared_ptr<void> buffer_reference, PjRtDevice* device) override;
+
+  // Note that literal must remain in scope until the transfer has completed, so
+  // the caller should, for example, wait for BlockHostUntilReady() completes on
+  // the return value before letting literal go out of scope.
+  StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
+      const LiteralSlice& literal, PjRtDevice* device) override;
+
+  // Asynchronously makes a vector of PjRtBuffers that can be used to receive
+  // cross host transfers using `client` on `device'. `shapes` must be the exact
+  // shapes, with identical layouts, corresponding to the buffers that will be
+  // sent. When resources for the transfer are available, notifier will be
+  // called with a vector of PjRtCrossHostRecvBuffer structs, one for each
+  // shape in `shapes`. Each struct contains a buffer that will contain the
+  // received value, and an opaque string that should be transmitted to the
+  // sending host and used in a call to CopyToRemoteDevice. None of the recv
+  // buffers will become ready until *all* of the sends have completed.
+  void MakeCrossHostReceiveBuffers(
+      absl::Span<const Shape> shapes, PjRtDevice* device,
+      PjRtCrossHostRecvNotifier&& notifier) override;
+
+  StatusOr<ChannelHandle> CreateChannelHandle() override {
     return client()->CreateChannelHandle();
   }
-  virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
+  StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
     return client()->CreateDeviceToHostChannelHandle();
   }
-  virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
+  StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
     return client()->CreateHostToDeviceChannelHandle();
   }

+  LocalDeviceState& device_state(int device_ordinal) const {
+    return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
+                local_devices_.at(device_ordinal))
+                ->local_device_state();
+  }
+  LocalClient* client() const { return client_; }
+  se::DeviceMemoryAllocator* allocator() const { return allocator_; }
+  tensorflow::Allocator* host_memory_allocator() const {
+    return host_memory_allocator_.get();
+  }
+  bool should_stage_host_to_device_transfers() const {
+    return should_stage_host_to_device_transfers_;
+  }
+
+  gpu::GpuExecutableRunOptions* gpu_run_options() const {
+    return gpu_run_options_.get();
+  }
+
+  tensorflow::thread::ThreadPool* h2d_transfer_pool() {
+    return &h2d_transfer_pool_;
+  }
+
  protected:
   friend class PjRtBuffer;
   virtual void EnqueueCrossHostReceive(
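After this split, application code is expected to program against the virtual `PjRtClient` / `PjRtDevice` API while only backend factories name `PjRtStreamExecutorClient`. A hedged sketch of that interface-level usage, assuming a client obtained elsewhere and at least one addressable device; `DescribeAndUpload` is illustrative:

xla::Status DescribeAndUpload(xla::PjRtClient* client,
                              const xla::LiteralSlice& literal) {
  VLOG(1) << "platform: " << client->platform_name()
          << ", devices: " << client->device_count()
          << ", addressable: " << client->addressable_device_count();
  // Place the literal on the first addressable device.
  xla::PjRtDevice* device = client->local_devices().front();
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtBuffer> buffer,
                      client->BufferFromHostLiteral(literal, device));
  // Per the header comment, `literal` must stay alive until the transfer is
  // done, so wait for the buffer before returning.
  return buffer->BlockHostUntilReady();
}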
@@ -342,7 +481,9 @@ class PjRtClient {
  std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;

  // Includes all devices, including non-local devices on multi-host platforms.
- std::vector<std::unique_ptr<PjRtDevice>> devices_;
+ std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_;
+ // Pointers to `owned_devices_`.
+ std::vector<PjRtDevice*> devices_;
  // Maps Device::id() to the corresponding Device. Includes all devices.
  std::map<int, PjRtDevice*> id_to_device_;
  // Local devices indexed by local device ordinal.
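The hunk above splits device storage into an owning vector of the concrete type plus a parallel vector of base-class pointers. The sketch below illustrates that ownership pattern in isolation; the class and member names are invented for the example and are not the PjRt types themselves.

#include <memory>
#include <vector>

// Illustration only: DeviceBase/ConcreteDevice are hypothetical stand-ins for
// the PjRtDevice/PjRtStreamExecutorDevice pair in the hunk above.
struct DeviceBase { virtual ~DeviceBase() = default; };
struct ConcreteDevice : DeviceBase { int ordinal = 0; };

class Client {
 public:
  explicit Client(std::vector<std::unique_ptr<ConcreteDevice>> devices)
      : owned_devices_(std::move(devices)) {
    devices_.reserve(owned_devices_.size());
    for (auto& d : owned_devices_) devices_.push_back(d.get());
  }
  // Callers only see the abstract view; lifetime is tied to owned_devices_.
  const std::vector<DeviceBase*>& devices() const { return devices_; }

 private:
  std::vector<std::unique_ptr<ConcreteDevice>> owned_devices_;  // owns
  std::vector<DeviceBase*> devices_;                            // borrows
};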
@@ -509,7 +650,7 @@ class PjRtBuffer {

   private:
    friend class PjRtBuffer;
-   friend class PjRtClient;
+   friend class PjRtStreamExecutorClient;

    // Helper struct that makes it possible to move a ScopedHold through a
    // closure.
@@ -769,7 +910,7 @@ class PjRtExecutable {
  virtual PjRtClient* client() const = 0;

  // Unique name for this executable, e.g., HloModule name.
- virtual const string& name() const = 0;
+ virtual const std::string& name() const = 0;

  virtual int num_replicas() const = 0;

@@ -791,6 +932,7 @@ class PjRtExecutable {
  virtual absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
      const = 0;

+ // An addressable_device is one which the client can issue commands to.
  // addressable_devices()[i] is the Device to which
  // addressable_device_logical_ids()[i] is assigned.
  virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0;
@@ -833,13 +975,14 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
      bool parameter_is_tupled_arguments,
      std::shared_ptr<DeviceAssignment> device_assignment,
      std::vector<LogicalDeviceIds> addressable_device_logical_ids,
-     std::vector<PjRtDevice*> addressable_devices, PjRtClient* client);
+     std::vector<PjRtDevice*> addressable_devices,
+     PjRtStreamExecutorClient* client);

  ~PjRtStreamExecutorExecutable() override = default;

- PjRtClient* client() const override { return client_; }
+ PjRtStreamExecutorClient* client() const override { return client_; }

- const string& name() const override;
+ const std::string& name() const override;

  int num_replicas() const override {
    return executables_[0]->build_options().num_replicas();
@@ -898,7 +1041,7 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
  }

 private:
- friend class PjRtClient;
+ friend class PjRtStreamExecutorClient;
  // Initializes information about which arguments to which executables must be
  // donated due to aliases that were specified by the computation.
  Status SetUpDonation(bool tuple_inputs);
@@ -933,7 +1076,7 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
  // Create shared pointers so we can free them after the execution: with
  // asynchronous execution, the process being executed can outlive the
  // executable itself.
- PjRtClient* const client_;
+ PjRtStreamExecutorClient* const client_;
  // One executable per partition.
  std::vector<std::shared_ptr<LocalExecutable>> executables_;
  // Per-executable set of parameters that have any aliased buffers and thus
@@ -94,10 +94,11 @@ Status TpuDeviceState::ThenMemcpyDeviceToDevice(
  return Status::OK();
}

-class PjRtTpuClient : public PjRtClient {
+class PjRtTpuClient : public PjRtStreamExecutorClient {
 public:
  PjRtTpuClient(LocalClient* client,
-               std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id);
+               std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
+               int host_id);

  StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
      int num_replicas, int num_partitions) const override;
@@ -108,14 +109,14 @@ class PjRtTpuClient : public PjRtClient {
      const PjRtExecutable& executable) const override;
};

-PjRtTpuClient::PjRtTpuClient(LocalClient* client,
-                             std::vector<std::unique_ptr<PjRtDevice>> devices,
-                             int host_id)
-    : PjRtClient(kTpuName, client, std::move(devices), host_id,
+PjRtTpuClient::PjRtTpuClient(
+    LocalClient* client,
+    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id)
+    : PjRtStreamExecutorClient(kTpuName, client, std::move(devices), host_id,
                 /*allocator=*/nullptr,
                 /*host_memory_allocator=*/nullptr,
                 /*should_stage_host_to_device_transfers=*/false,
                 /*gpu_run_options=*/nullptr) {}

StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
    int num_replicas, int num_partitions) const {
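The constructor above shows the general migration pattern in this commit: a platform-specific client now derives from the StreamExecutor-backed base class and simply forwards its arguments. The sketch below restates that shape for a hypothetical backend; `MyPlatformClient` and the "myplatform" name are invented, and the base-class argument list is copied from the TPU hunk above rather than independently verified.

// Sketch only: a hypothetical platform client following the same pattern.
class MyPlatformClient : public PjRtStreamExecutorClient {
 public:
  MyPlatformClient(LocalClient* client,
                   std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
                   int host_id)
      : PjRtStreamExecutorClient("myplatform", client, std::move(devices),
                                 host_id,
                                 /*allocator=*/nullptr,
                                 /*host_memory_allocator=*/nullptr,
                                 /*should_stage_host_to_device_transfers=*/false,
                                 /*gpu_run_options=*/nullptr) {}
};

Platform devices follow the same pattern, deriving from PjRtStreamExecutorDevice, as the TPU device hunks below show.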
@@ -128,7 +129,8 @@ StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
                                        num_partitions);
  }
  // Fallback to default global device assignment if we can't run locally.
- return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions);
+ return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
+                                                             num_partitions);
}

StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
@@ -152,10 +154,10 @@ StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
  return absl::optional<std::string>(tpu_executable->fingerprint());
}

-StatusOr<std::vector<std::unique_ptr<PjRtDevice>>> GetTpuDevices(
+StatusOr<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>> GetTpuDevices(
    LocalClient* client,
    std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
- std::vector<std::unique_ptr<PjRtDevice>> devices;
+ std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
  tf_tpu::TpuTopologyExternal topology =
      tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();

@@ -26,14 +26,14 @@ limitations under the License.

namespace xla {

-class PjRtTpuDevice : public PjRtDevice {
+class PjRtTpuDevice : public PjRtStreamExecutorDevice {
 public:
  PjRtTpuDevice(const tensorflow::tpu::TpuCoreLocationExternal core,
                std::unique_ptr<LocalDeviceState> local_device_state,
                int host_id, const std::array<int, 3>& coords,
                std::string device_kind)
-     : PjRtDevice(core.Id(), std::move(local_device_state),
+     : PjRtStreamExecutorDevice(core.Id(), std::move(local_device_state),
                   std::move(device_kind), host_id),
        core_(core),
        coords_(coords) {}
@@ -97,13 +97,13 @@ cc_library(
    name = "types",
    srcs = ["types.cc"],
    hdrs = ["types.h"],
+   compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    deps = [
-       ":bfloat16",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status",
@@ -113,6 +113,7 @@ cc_library(
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/pjrt:pjrt_client",
        "//tensorflow/core:lib",
+       "//tensorflow/python:bfloat16_lib",
        "//third_party/py/numpy:headers",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:inlined_vector",
@@ -158,42 +159,6 @@ cc_library(
    ],
)

-cc_library(
-   name = "bfloat16",
-   srcs = ["bfloat16.cc"],
-   hdrs = ["bfloat16.h"],
-   copts = [
-       "-fexceptions",
-       "-fno-strict-aliasing",
-   ],
-   features = ["-use_header_modules"],
-   deps = [
-       "//tensorflow/compiler/xla:statusor",
-       "//tensorflow/compiler/xla:types",
-       "//tensorflow/compiler/xla:util",
-       "//tensorflow/core/platform:bfloat16",
-       "//tensorflow/core/platform:logging",
-       "//third_party/py/numpy:headers",
-       "//third_party/python_runtime:headers",  # buildcleaner: keep
-       "@com_google_absl//absl/strings",
-       "@pybind11",
-   ],
-)
-
-py_test(
-   name = "bfloat16_test",
-   srcs = ["bfloat16_test.py"],
-   main = "bfloat16_test.py",
-   python_version = "PY3",
-   tags = ["no_oss"],
-   deps = [
-       ":xla_client",
-       ":xla_extension",
-       "@absl_py//absl/testing:absltest",
-       "@absl_py//absl/testing:parameterized",
-   ] + xla_py_test_deps(),
-)
-
cc_library(
    name = "py_client",
    srcs = [
@@ -206,6 +171,7 @@ cc_library(
        "py_client.h",
        "py_executable.h",
    ],
+   compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
@@ -232,6 +198,7 @@ cc_library(
    name = "dlpack",
    srcs = ["dlpack.cc"],
    hdrs = ["dlpack.h"],
+   compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
@@ -263,6 +230,7 @@ cc_library(
    name = "jax_jit",
    srcs = ["jax_jit.cc"],
    hdrs = ["jax_jit.h"],
+   compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
@@ -292,6 +260,7 @@ cc_library(
    name = "ops",
    srcs = ["ops.cc"],
    hdrs = ["ops.h"],
+   compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
@@ -356,6 +325,7 @@ cc_library(
    name = "outfeed_receiver_py",
    srcs = ["outfeed_receiver_py.cc"],
    hdrs = ["outfeed_receiver_py.h"],
+   compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
@@ -379,12 +349,14 @@ cc_library(
    name = "pytree",
    srcs = ["pytree.cc"],
    hdrs = ["pytree.h"],
+   compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    deps = [
+       ":types",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/hash",
@@ -435,6 +407,7 @@ cc_library(
    name = "xla_compiler",
    srcs = ["xla_compiler.cc"],
    hdrs = ["xla_compiler.h"],
+   compatible_with = [],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
@@ -481,7 +454,6 @@ pybind_extension(
    features = ["-use_header_modules"],
    module_name = "xla_extension",
    deps = [
-       ":bfloat16",
        ":dlpack",
        ":jax_jit",
        ":ops",
@@ -534,6 +506,7 @@ pybind_extension(
        # without any TF dependencies as "jaxlib" on Pypi, and "jaxlib" does
        # not require Tensorflow.
        "//tensorflow/core:lib_internal_impl",  # buildcleaner: keep
+       "//tensorflow/python:bfloat16_lib",
        "//tensorflow/stream_executor:device_memory_allocator",
        "//tensorflow/stream_executor:platform",
    ] + select({
File diff suppressed because it is too large
@@ -1,440 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for the bfloat16 Python type."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import itertools
import math

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np

from tensorflow.compiler.xla.python import xla_client

bfloat16 = xla_client.bfloat16


def numpy_assert_allclose(a, b, **kwargs):
  a = a.astype(np.float32) if a.dtype == bfloat16 else a
  b = b.astype(np.float32) if b.dtype == bfloat16 else b
  return np.testing.assert_allclose(a, b, **kwargs)


epsilon = float.fromhex("1.0p-7")

# Values that should round trip exactly to float and back.
FLOAT_VALUES = [
    0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
    -1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
    float("inf"),
    float("-inf"),
    float("nan")
]


class Bfloat16Test(parameterized.TestCase):
  """Tests the non-numpy Python methods of the bfloat16 type."""

  def testRoundTripToFloat(self):
    for v in FLOAT_VALUES:
      np.testing.assert_equal(v, float(bfloat16(v)))

  def testRoundTripNumpyTypes(self):
    for dtype in [np.float16, np.float32, np.float64]:
      np.testing.assert_equal(-3.75, dtype(bfloat16(dtype(-3.75))))
      np.testing.assert_equal(1.5, float(bfloat16(dtype(1.5))))
      np.testing.assert_equal(4.5, dtype(bfloat16(np.array(4.5, dtype))))
      np.testing.assert_equal(
          np.array([2, 5, -1], bfloat16), bfloat16(np.array([2, 5, -1], dtype)))

  def testRoundTripToInt(self):
    for v in [-256, -255, -34, -2, -1, 0, 1, 2, 10, 47, 128, 255, 256, 512]:
      self.assertEqual(v, int(bfloat16(v)))

  # pylint: disable=g-complex-comprehension
  @parameterized.named_parameters(({
      "testcase_name": "_" + dtype.__name__,
      "dtype": dtype
  } for dtype in [bfloat16, np.float16, np.float32, np.float64]))
  def testRoundTripToNumpy(self, dtype):
    for v in FLOAT_VALUES:
      np.testing.assert_equal(v, bfloat16(dtype(v)))
      np.testing.assert_equal(v, dtype(bfloat16(dtype(v))))
      np.testing.assert_equal(v, dtype(bfloat16(np.array(v, dtype))))
    if dtype != bfloat16:
      np.testing.assert_equal(
          np.array(FLOAT_VALUES, dtype),
          bfloat16(np.array(FLOAT_VALUES, dtype)).astype(dtype))

  def testStr(self):
    self.assertEqual("0", str(bfloat16(0.0)))
    self.assertEqual("1", str(bfloat16(1.0)))
    self.assertEqual("-3.5", str(bfloat16(-3.5)))
    self.assertEqual("0.0078125", str(bfloat16(float.fromhex("1.0p-7"))))
    self.assertEqual("inf", str(bfloat16(float("inf"))))
    self.assertEqual("-inf", str(bfloat16(float("-inf"))))
    self.assertEqual("nan", str(bfloat16(float("nan"))))

  def testRepr(self):
    self.assertEqual("0", repr(bfloat16(0)))
    self.assertEqual("1", repr(bfloat16(1)))
    self.assertEqual("-3.5", repr(bfloat16(-3.5)))
    self.assertEqual("0.0078125", repr(bfloat16(float.fromhex("1.0p-7"))))
    self.assertEqual("inf", repr(bfloat16(float("inf"))))
    self.assertEqual("-inf", repr(bfloat16(float("-inf"))))
    self.assertEqual("nan", repr(bfloat16(float("nan"))))

  def testHash(self):
    self.assertEqual(0, hash(bfloat16(0.0)))
    self.assertEqual(0x3f80, hash(bfloat16(1.0)))
    self.assertEqual(0x7fc0, hash(bfloat16(float("nan"))))

  # Tests for Python operations
  def testNegate(self):
    for v in FLOAT_VALUES:
      np.testing.assert_equal(-v, float(-bfloat16(v)))

  def testAdd(self):
    np.testing.assert_equal(0, float(bfloat16(0) + bfloat16(0)))
    np.testing.assert_equal(1, float(bfloat16(1) + bfloat16(0)))
    np.testing.assert_equal(0, float(bfloat16(1) + bfloat16(-1)))
    np.testing.assert_equal(5.5, float(bfloat16(2) + bfloat16(3.5)))
    np.testing.assert_equal(1.25, float(bfloat16(3.5) + bfloat16(-2.25)))
    np.testing.assert_equal(
        float("inf"), float(bfloat16(float("inf")) + bfloat16(-2.25)))
    np.testing.assert_equal(
        float("-inf"), float(bfloat16(float("-inf")) + bfloat16(-2.25)))
    self.assertTrue(math.isnan(float(bfloat16(3.5) + bfloat16(float("nan")))))

    # Test type promotion against Numpy scalar values.
    self.assertEqual(np.float32, type(bfloat16(3.5) + np.float16(2.25)))
    self.assertEqual(np.float32, type(np.float16(3.5) + bfloat16(2.25)))
    self.assertEqual(np.float32, type(bfloat16(3.5) + np.float32(2.25)))
    self.assertEqual(np.float32, type(np.float32(3.5) + bfloat16(2.25)))
    self.assertEqual(np.float64, type(bfloat16(3.5) + np.float64(2.25)))
    self.assertEqual(np.float64, type(np.float64(3.5) + bfloat16(2.25)))
    self.assertEqual(np.float64, type(bfloat16(3.5) + float(2.25)))
    self.assertEqual(np.float64, type(float(3.5) + bfloat16(2.25)))
    self.assertEqual(np.float32,
                     type(bfloat16(3.5) + np.array(2.25, np.float32)))
    self.assertEqual(np.float32,
                     type(np.array(3.5, np.float32) + bfloat16(2.25)))

  def testSub(self):
    np.testing.assert_equal(0, float(bfloat16(0) - bfloat16(0)))
    np.testing.assert_equal(1, float(bfloat16(1) - bfloat16(0)))
    np.testing.assert_equal(2, float(bfloat16(1) - bfloat16(-1)))
    np.testing.assert_equal(-1.5, float(bfloat16(2) - bfloat16(3.5)))
    np.testing.assert_equal(5.75, float(bfloat16(3.5) - bfloat16(-2.25)))
    np.testing.assert_equal(
        float("-inf"), float(bfloat16(-2.25) - bfloat16(float("inf"))))
    np.testing.assert_equal(
        float("inf"), float(bfloat16(-2.25) - bfloat16(float("-inf"))))
    self.assertTrue(math.isnan(float(bfloat16(3.5) - bfloat16(float("nan")))))

  def testMul(self):
    np.testing.assert_equal(0, float(bfloat16(0) * bfloat16(0)))
    np.testing.assert_equal(0, float(bfloat16(1) * bfloat16(0)))
    np.testing.assert_equal(-1, float(bfloat16(1) * bfloat16(-1)))
    np.testing.assert_equal(-7.875, float(bfloat16(3.5) * bfloat16(-2.25)))
    np.testing.assert_equal(
        float("-inf"), float(bfloat16(float("inf")) * bfloat16(-2.25)))
    np.testing.assert_equal(
        float("inf"), float(bfloat16(float("-inf")) * bfloat16(-2.25)))
    self.assertTrue(math.isnan(float(bfloat16(3.5) * bfloat16(float("nan")))))

  def testDiv(self):
    self.assertTrue(math.isnan(float(bfloat16(0) / bfloat16(0))))
    np.testing.assert_equal(float("inf"), float(bfloat16(1) / bfloat16(0)))
    np.testing.assert_equal(-1, float(bfloat16(1) / bfloat16(-1)))
    np.testing.assert_equal(-1.75, float(bfloat16(3.5) / bfloat16(-2)))
    np.testing.assert_equal(
        float("-inf"), float(bfloat16(float("inf")) / bfloat16(-2.25)))
    np.testing.assert_equal(
        float("inf"), float(bfloat16(float("-inf")) / bfloat16(-2.25)))
    self.assertTrue(math.isnan(float(bfloat16(3.5) / bfloat16(float("nan")))))

  def testLess(self):
    for v in FLOAT_VALUES:
      for w in FLOAT_VALUES:
        self.assertEqual(v < w, bfloat16(v) < bfloat16(w))

  def testLessEqual(self):
    for v in FLOAT_VALUES:
      for w in FLOAT_VALUES:
        self.assertEqual(v <= w, bfloat16(v) <= bfloat16(w))

  def testGreater(self):
    for v in FLOAT_VALUES:
      for w in FLOAT_VALUES:
        self.assertEqual(v > w, bfloat16(v) > bfloat16(w))

  def testGreaterEqual(self):
    for v in FLOAT_VALUES:
      for w in FLOAT_VALUES:
        self.assertEqual(v >= w, bfloat16(v) >= bfloat16(w))

  def testEqual(self):
    for v in FLOAT_VALUES:
      for w in FLOAT_VALUES:
        self.assertEqual(v == w, bfloat16(v) == bfloat16(w))

  def testNotEqual(self):
    for v in FLOAT_VALUES:
      for w in FLOAT_VALUES:
        self.assertEqual(v != w, bfloat16(v) != bfloat16(w))

  def testNan(self):
    a = np.isnan(bfloat16(float("nan")))
    self.assertTrue(a)
    numpy_assert_allclose(np.array([1.0, a]), np.array([1.0, a]))

    a = np.array([bfloat16(1.34375),
                  bfloat16(1.4375),
                  bfloat16(float("nan"))],
                 dtype=bfloat16)
    b = np.array(
        [bfloat16(1.3359375),
         bfloat16(1.4375),
         bfloat16(float("nan"))],
        dtype=bfloat16)
    numpy_assert_allclose(
        a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True)

  def testSort(self):
    values_to_sort = np.float32(FLOAT_VALUES)
    sorted_f32 = np.sort(values_to_sort)
    sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
    np.testing.assert_equal(sorted_f32, np.float32(sorted_bf16))


BinaryOp = collections.namedtuple("BinaryOp", ["op"])

UNARY_UFUNCS = [
    np.negative, np.positive, np.absolute, np.fabs, np.rint, np.sign,
    np.conjugate, np.exp, np.exp2, np.expm1, np.log, np.log10, np.log1p,
    np.log2, np.sqrt, np.square, np.cbrt, np.reciprocal, np.sin, np.cos, np.tan,
    np.arcsin, np.arccos, np.arctan, np.sinh, np.cosh, np.tanh, np.arcsinh,
    np.arccosh, np.arctanh, np.deg2rad, np.rad2deg, np.floor, np.ceil, np.trunc
]

BINARY_UFUNCS = [
    np.add, np.subtract, np.multiply, np.divide, np.logaddexp, np.logaddexp2,
    np.floor_divide, np.power, np.remainder, np.fmod, np.heaviside, np.arctan2,
    np.hypot, np.maximum, np.minimum, np.fmax, np.fmin, np.copysign
]

BINARY_PREDICATE_UFUNCS = [
    np.equal, np.not_equal, np.less, np.greater, np.less_equal,
    np.greater_equal, np.logical_and, np.logical_or, np.logical_xor
]


class Bfloat16NumPyTest(parameterized.TestCase):
  """Tests the NumPy integration of the bfloat16 type."""

  def testDtype(self):
    self.assertEqual(bfloat16, np.dtype(bfloat16))

  def testDeepCopyDoesNotAlterHash(self):
    # For context, see https://github.com/google/jax/issues/4651. If the hash
    # value of the type descriptor is not initialized correctly, a deep copy
    # can change the type hash.
    dtype = np.dtype(bfloat16)
    h = hash(dtype)
    _ = copy.deepcopy(dtype)
    self.assertEqual(h, hash(dtype))

  def testArray(self):
    x = np.array([[1, 2, 3]], dtype=bfloat16)
    self.assertEqual(bfloat16, x.dtype)
    self.assertEqual("[[1 2 3]]", str(x))
    np.testing.assert_equal(x, x)
    numpy_assert_allclose(x, x)
    self.assertTrue((x == x).all())

  def testComparisons(self):
    x = np.array([401408, 7, -32], dtype=np.float32)
    bx = x.astype(bfloat16)
    y = np.array([82432, 7, 0], dtype=np.float32)
    by = y.astype(bfloat16)
    np.testing.assert_equal(x == y, bx == by)
    np.testing.assert_equal(x != y, bx != by)
    np.testing.assert_equal(x < y, bx < by)
    np.testing.assert_equal(x > y, bx > by)
    np.testing.assert_equal(x <= y, bx <= by)
    np.testing.assert_equal(x >= y, bx >= by)

  def testEqual2(self):
    a = np.array([401408], bfloat16)
    b = np.array([82432], bfloat16)
    self.assertFalse(a.__eq__(b))

  def testCasts(self):
    for dtype in [
        np.float16, np.float32, np.float64, np.int8, np.int16, np.int32,
        np.int64, np.complex64, np.complex128, np.uint8, np.uint16, np.uint32,
        np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong
    ]:
      x = np.array([[1, 2, 3]], dtype=dtype)
      y = x.astype(bfloat16)
      z = y.astype(dtype)
      self.assertTrue(np.all(x == y))
      self.assertEqual(bfloat16, y.dtype)
      self.assertTrue(np.all(x == z))
      self.assertEqual(dtype, z.dtype)

  def testConformNumpyComplex(self):
    for dtype in [np.complex64, np.complex128]:
      x = np.array([1.1, 2.2 + 2.2j, 3.3], dtype=dtype)
      y_np = x.astype(np.float32)
      y_tf = x.astype(bfloat16)
      numpy_assert_allclose(y_np, y_tf, atol=2e-2)

      z_np = y_np.astype(dtype)
      z_tf = y_tf.astype(dtype)
      numpy_assert_allclose(z_np, z_tf, atol=2e-2)

  def testArange(self):
    np.testing.assert_equal(
        np.arange(100, dtype=np.float32).astype(bfloat16),
        np.arange(100, dtype=bfloat16))
    np.testing.assert_equal(
        np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16),
        np.arange(-10.5, 7.8, 0.5, dtype=bfloat16))
    np.testing.assert_equal(
        np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16),
        np.arange(-0., -7., -0.25, dtype=bfloat16))
    np.testing.assert_equal(
        np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16),
        np.arange(-16384., 16384., 64., dtype=bfloat16))

  # pylint: disable=g-complex-comprehension
  @parameterized.named_parameters(({
      "testcase_name": "_" + op.__name__,
      "op": op
  } for op in UNARY_UFUNCS))
  def testUnaryUfunc(self, op):
    rng = np.random.RandomState(seed=42)
    x = rng.randn(3, 7, 10).astype(bfloat16)
    numpy_assert_allclose(
        op(x).astype(np.float32), op(x.astype(np.float32)), rtol=1e-2)

  @parameterized.named_parameters(({
      "testcase_name": "_" + op.__name__,
      "op": op
  } for op in BINARY_UFUNCS))
  def testBinaryUfunc(self, op):
    rng = np.random.RandomState(seed=42)
    x = rng.randn(3, 7, 10).astype(bfloat16)
    y = rng.randn(4, 1, 7, 10).astype(bfloat16)
    numpy_assert_allclose(
        op(x, y).astype(np.float32),
        op(x.astype(np.float32), y.astype(np.float32)),
        rtol=1e-2)

  @parameterized.named_parameters(({
      "testcase_name": "_" + op.__name__,
      "op": op
  } for op in BINARY_PREDICATE_UFUNCS))
  def testBinaryPredicateUfunc(self, op):
    rng = np.random.RandomState(seed=42)
    x = rng.randn(3, 7).astype(bfloat16)
    y = rng.randn(4, 1, 7).astype(bfloat16)
    np.testing.assert_equal(
        op(x, y), op(x.astype(np.float32), y.astype(np.float32)))

  @parameterized.named_parameters(({
      "testcase_name": "_" + op.__name__,
      "op": op
  } for op in [np.isfinite, np.isinf, np.isnan, np.signbit, np.logical_not]))
  def testPredicateUfunc(self, op):
    rng = np.random.RandomState(seed=42)
    shape = (3, 7, 10)
    posinf_flips = rng.rand(*shape) < 0.1
    neginf_flips = rng.rand(*shape) < 0.1
    nan_flips = rng.rand(*shape) < 0.1
    vals = rng.randn(*shape)
    vals = np.where(posinf_flips, np.inf, vals)
    vals = np.where(neginf_flips, -np.inf, vals)
    vals = np.where(nan_flips, np.nan, vals)
    vals = vals.astype(bfloat16)
    np.testing.assert_equal(op(vals), op(vals.astype(np.float32)))

  def testDivmod(self):
    rng = np.random.RandomState(seed=42)
    x = rng.randn(3, 7).astype(bfloat16)
    y = rng.randn(4, 1, 7).astype(bfloat16)
    o1, o2 = np.divmod(x, y)
    e1, e2 = np.divmod(x.astype(np.float32), y.astype(np.float32))
    numpy_assert_allclose(o1, e1, rtol=1e-2)
    numpy_assert_allclose(o2, e2, rtol=1e-2)

  def testModf(self):
    rng = np.random.RandomState(seed=42)
    x = rng.randn(3, 7).astype(bfloat16)
    o1, o2 = np.modf(x)
    e1, e2 = np.modf(x.astype(np.float32))
    numpy_assert_allclose(o1.astype(np.float32), e1, rtol=1e-2)
    numpy_assert_allclose(o2.astype(np.float32), e2, rtol=1e-2)

  def testLdexp(self):
    rng = np.random.RandomState(seed=42)
    x = rng.randn(3, 7).astype(bfloat16)
    y = rng.randint(-50, 50, (1, 7))
    numpy_assert_allclose(
        np.ldexp(x, y).astype(np.float32),
        np.ldexp(x.astype(np.float32), y),
        rtol=1e-2,
        atol=1e-6)

  def testFrexp(self):
    rng = np.random.RandomState(seed=42)
    x = rng.randn(3, 7).astype(bfloat16)
    mant1, exp1 = np.frexp(x)
    mant2, exp2 = np.frexp(x.astype(np.float32))
    np.testing.assert_equal(exp1, exp2)
    numpy_assert_allclose(mant1, mant2, rtol=1e-2)

  def testNextAfter(self):
    one = np.array(1., dtype=bfloat16)
    two = np.array(2., dtype=bfloat16)
    zero = np.array(0., dtype=bfloat16)
    nan = np.array(np.nan, dtype=bfloat16)
    np.testing.assert_equal(np.nextafter(one, two) - one, epsilon)
    np.testing.assert_equal(np.nextafter(one, zero) - one, -epsilon / 2)
    np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True)
    np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True)
    np.testing.assert_equal(np.nextafter(one, one), one)
    smallest_denormal = float.fromhex("1.0p-133")
    np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal)
    np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal)
    for a, b in itertools.permutations([0., -0., nan], 2):
      np.testing.assert_equal(
          np.nextafter(
              np.array(a, dtype=np.float32), np.array(b, dtype=np.float32)),
          np.nextafter(
              np.array(a, dtype=bfloat16), np.array(b, dtype=bfloat16)))


if __name__ == "__main__":
  absltest.main()
@@ -214,11 +214,9 @@ StatusOr<std::vector<int64>> StridesToLayout(absl::Span<int64 const> dims,
}

StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
- const se::Platform* platform =
-     device.local_device_state()->executor()->platform();
- if (platform->id() == se::host::kHostPlatformId) {
+ if (device.client()->platform_id() == kCpuId) {
    return kDLCPU;
- } else if (platform->id() == se::cuda::kCudaPlatformId) {
+ } else if (device.client()->platform_id() == kGpuId) {
    return kDLGPU;
  }
  return InvalidArgument("Device %s cannot be used as a DLPack device.",
@@ -228,7 +226,7 @@ StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
StatusOr<DLContext> DLContextForDevice(const PjRtDevice& device) {
  DLContext context;
  TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
- context.device_id = device.local_device_id();
+ context.device_id = device.local_hardware_id();
  return context;
}

@@ -241,14 +239,14 @@ StatusOr<PjRtDevice*> DeviceForDLContext(const PjRtClient& client,
            "DLPack CPU device type mismatch with PjRtClient platform %s",
            client.platform_name());
      }
-     return client.LookupLocalDevice(context.device_id);
+     return client.LookupAddressableDevice(context.device_id);
    case kDLGPU:
      if (client.platform_id() != kGpuId) {
        return InvalidArgument(
            "DLPack GPU device type mismatch with PjRtClient platform %s",
            client.platform_name());
      }
-     return client.LookupLocalDevice(context.device_id);
+     return client.LookupAddressableDevice(context.device_id);
    default:
      return InvalidArgument("Unknown/unsupported DLPack device type %d",
                             context.device_type);
@@ -297,7 +295,7 @@ StatusOr<py::capsule> BufferToDLPackManagedTensor(py::handle py_buffer,
  pack->tensor.manager_ctx = pack.get();
  pack->tensor.deleter = DLPackTensorDeleter;
  TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device()));
- dt.ctx.device_id = buffer->buffer()->device()->local_device_id();
+ dt.ctx.device_id = buffer->buffer()->device()->local_hardware_id();
  dt.ndim = buffer->buffer()->on_host_shape().dimensions_size();
  TF_ASSIGN_OR_RETURN(dt.dtype,
                      PrimitiveTypeToDLDataType(
@@ -230,8 +230,8 @@ OutfeedReceiverImpl::OutfeedReceiverImpl(
  callback_ = callback;
  max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
  for (const auto& client : clients) {
-   for (const auto& device : client->devices()) {
-     devices_.push_back(device.get());
+   for (auto device : client->devices()) {
+     devices_.push_back(device);
    }
  }
  CHECK_GT(devices_.size(), 0);
@@ -342,11 +342,7 @@ StatusOr<std::unique_ptr<Literal>> OutfeedReceiverImpl::ReceiveRawFromOutfeed(
    const PjRtDevice* device, const Shape& shape) {
  std::shared_ptr<Literal> literal_shared;

- TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
-                     device->GetLocalDeviceState());
- TF_ASSIGN_OR_RETURN(Literal literal,
-                     local_device->client()->TransferFromOutfeedLocal(
-                         shape, local_device->device_ordinal()));
+ TF_ASSIGN_OR_RETURN(Literal literal, device->TransferFromOutfeed(shape));

  return absl::make_unique<Literal>(std::move(literal));
}
@@ -86,8 +86,8 @@ StatusOr<std::uintptr_t> PyBuffer::UnsafeBufferPointer() const {
}

StatusOr<py::dict> PyBuffer::CudaArrayInterface() const {
- if (buffer_->device()->local_device_state()->executor()->platform_kind() !=
-     se::PlatformKind::kCuda) {
+ // TODO(zhangqiaorjc): Differentiate between NVidia and other GPUs.
+ if (buffer_->client()->platform_id() != kGpuId) {
    return InvalidArgument(
        "__cuda_array_interface__ is only defined for NVidia GPU buffers.");
  }
@@ -37,9 +37,10 @@ PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)

std::vector<ClientAndPtr<PjRtDevice>> PyClient::Devices() {
  std::vector<ClientAndPtr<PjRtDevice>> devices;
- devices.reserve(pjrt_client_->devices().size());
- for (const auto& device : pjrt_client_->devices()) {
-   devices.push_back(WrapWithClient(shared_from_this(), device.get()));
+ auto span = pjrt_client_->devices();
+ devices.reserve(span.size());
+ for (PjRtDevice* device : span) {
+   devices.push_back(WrapWithClient(shared_from_this(), device));
  }
  return devices;
}
@@ -64,9 +65,9 @@ PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
    result[r].resize(num_partitions);
    for (int p = 0; p < num_partitions; ++p) {
      int device_id = device_assignment(r, p);
-     auto iter = pjrt_client_->id_to_device().find(device_id);
-     CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
-     result[r][p] = WrapWithClient(shared_from_this(), iter->second);
+     TF_ASSIGN_OR_RETURN(PjRtDevice * device,
+                         pjrt_client_->LookupDevice(device_id));
+     result[r][p] = WrapWithClient(shared_from_this(), device);
    }
  }
  return result;
@@ -80,9 +81,9 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
  std::vector<ClientAndPtr<PjRtDevice>> result;
  for (int i = 0; i < num_replicas; ++i) {
    int device_id = device_assignment(i, 0);
-   auto iter = pjrt_client_->id_to_device().find(device_id);
-   CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
-   result.push_back(WrapWithClient(shared_from_this(), iter->second));
+   TF_ASSIGN_OR_RETURN(PjRtDevice * device,
+                       pjrt_client_->LookupDevice(device_id));
+   result.push_back(WrapWithClient(shared_from_this(), device));
  }
  return result;
}
@@ -95,8 +96,9 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
    device = pjrt_client_->local_devices().front();
  }
  CHECK(device != nullptr);
- auto iter = pjrt_client_->id_to_device().find(device->id());
- if (iter->second != device) {
+ TF_ASSIGN_OR_RETURN(PjRtDevice * found_device,
+                     pjrt_client_->LookupDevice(device->id()));
+ if (found_device != device) {
    return InvalidArgument("Cannot copy value to device '%s' with '%s' backend",
                           device->DebugString(),
                           pjrt_client_->platform_name());
@@ -97,7 +97,9 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
  const std::string& platform_name() const {
    return pjrt_client_->platform_name();
  }
- int local_device_count() const { return pjrt_client_->local_device_count(); }
+ int addressable_device_count() const {
+   return pjrt_client_->addressable_device_count();
+ }
  int device_count() const { return pjrt_client_->device_count(); }
  int host_id() const { return pjrt_client_->host_id(); }

@@ -32,6 +32,7 @@ limitations under the License.
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
+#include "tensorflow/compiler/xla/python/types.h"

namespace xla {

@@ -106,59 +107,66 @@ bool PyTreeDef::operator==(const PyTreeDef& other) const {
  }
}

-void PyTreeDef::FlattenInto(py::handle handle,
-                            std::vector<py::object>& leaves) {
+void PyTreeDef::FlattenInto(py::handle handle, std::vector<py::object>& leaves,
+                            absl::optional<py::function> leaf_predicate) {
  Node node;
  int start_num_nodes = traversal_.size();
  int start_num_leaves = leaves.size();
- node.kind = GetKind(handle, &node.custom);
- if (node.kind == Kind::kNone) {
-   // Nothing to do.
- } else if (node.kind == Kind::kTuple) {
-   py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
-   node.arity = tuple.size();
-   for (py::handle entry : tuple) {
-     FlattenInto(entry, leaves);
-   }
- } else if (node.kind == Kind::kList) {
-   py::list list = py::reinterpret_borrow<py::list>(handle);
-   node.arity = list.size();
-   for (py::handle entry : list) {
-     FlattenInto(entry, leaves);
-   }
- } else if (node.kind == Kind::kDict) {
-   py::dict dict = py::reinterpret_borrow<py::dict>(handle);
-   py::list keys = py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
-   if (PyList_Sort(keys.ptr())) {
-     throw std::runtime_error("Dictionary key sort failed.");
-   }
-   for (py::handle key : keys) {
-     FlattenInto(dict[key], leaves);
-   }
-   node.arity = dict.size();
-   node.node_data = std::move(keys);
- } else if (node.kind == Kind::kCustom) {
-   py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
-   if (out.size() != 2) {
-     throw std::runtime_error(
-         "PyTree custom to_iterable function should return a pair");
-   }
-   node.node_data = out[1];
-   node.arity = 0;
-   for (py::handle entry : py::cast<py::iterable>(out[0])) {
-     ++node.arity;
-     FlattenInto(entry, leaves);
-   }
- } else if (node.kind == Kind::kNamedTuple) {
-   py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
-   node.arity = tuple.size();
-   node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
-   for (py::handle entry : tuple) {
-     FlattenInto(entry, leaves);
-   }
- } else {
-   assert(node.kind == Kind::kLeaf);
-   leaves.push_back(pybind11::reinterpret_borrow<py::object>(handle));
+ if (leaf_predicate && (*leaf_predicate)(handle).cast<bool>()) {
+   leaves.push_back(py::reinterpret_borrow<py::object>(handle));
+ } else {
+   node.kind = GetKind(handle, &node.custom);
+   auto recurse = [this, &leaf_predicate, &leaves](py::handle child) {
+     FlattenInto(child, leaves, leaf_predicate);
+   };
+   if (node.kind == Kind::kNone) {
+     // Nothing to do.
+   } else if (node.kind == Kind::kTuple) {
+     py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
+     node.arity = tuple.size();
+     for (py::handle entry : tuple) {
+       recurse(entry);
+     }
+   } else if (node.kind == Kind::kList) {
+     py::list list = py::reinterpret_borrow<py::list>(handle);
+     node.arity = list.size();
+     for (py::handle entry : list) {
+       recurse(entry);
+     }
+   } else if (node.kind == Kind::kDict) {
+     py::dict dict = py::reinterpret_borrow<py::dict>(handle);
+     py::list keys = py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
+     if (PyList_Sort(keys.ptr())) {
+       throw std::runtime_error("Dictionary key sort failed.");
+     }
+     for (py::handle key : keys) {
+       recurse(dict[key]);
+     }
+     node.arity = dict.size();
+     node.node_data = std::move(keys);
+   } else if (node.kind == Kind::kCustom) {
+     py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
+     if (out.size() != 2) {
+       throw std::runtime_error(
+           "PyTree custom to_iterable function should return a pair");
+     }
+     node.node_data = out[1];
+     node.arity = 0;
+     for (py::handle entry : py::cast<py::iterable>(out[0])) {
+       ++node.arity;
+       recurse(entry);
+     }
+   } else if (node.kind == Kind::kNamedTuple) {
+     py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
+     node.arity = tuple.size();
+     node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
+     for (py::handle entry : tuple) {
+       recurse(entry);
+     }
+   } else {
+     assert(node.kind == Kind::kLeaf);
+     leaves.push_back(py::reinterpret_borrow<py::object>(handle));
+   }
  }
  node.num_nodes = traversal_.size() - start_num_nodes + 1;
  node.num_leaves = leaves.size() - start_num_leaves;
@@ -166,10 +174,10 @@ void PyTreeDef::FlattenInto(py::handle handle,
}

/*static*/ std::pair<std::vector<py::object>, std::unique_ptr<PyTreeDef>>
-PyTreeDef::Flatten(py::handle x) {
+PyTreeDef::Flatten(py::handle x, absl::optional<py::function> leaf_predicate) {
  std::vector<py::object> leaves;
  auto tree = absl::make_unique<PyTreeDef>();
- tree->FlattenInto(x, leaves);
+ tree->FlattenInto(x, leaves, leaf_predicate);
  return std::make_pair(std::move(leaves), std::move(tree));
}

@@ -618,7 +626,8 @@ std::string PyTreeDef::ToString() const {

void BuildPytreeSubmodule(py::module& m) {
  py::module pytree = m.def_submodule("pytree", "Python tree library");
- pytree.def("flatten", &PyTreeDef::Flatten);
+ pytree.def("flatten", &PyTreeDef::Flatten, py::arg("tree"),
+            py::arg("leaf_predicate") = absl::nullopt);
  pytree.def("tuple", &PyTreeDef::Tuple);
  pytree.def("all_leaves", &PyTreeDef::AllLeaves);

@@ -85,11 +85,13 @@ class PyTreeDef {

  // Flattens a Pytree into a list of leaves and a PyTreeDef.
  static std::pair<std::vector<pybind11::object>, std::unique_ptr<PyTreeDef>>
- Flatten(pybind11::handle x);
+ Flatten(pybind11::handle x,
+         absl::optional<pybind11::function> leaf_predicate = absl::nullopt);

  // Recursive helper used to implement Flatten().
- void FlattenInto(pybind11::handle handle,
-                  std::vector<pybind11::object>& leaves);
+ void FlattenInto(
+     pybind11::handle handle, std::vector<pybind11::object>& leaves,
+     absl::optional<pybind11::function> leaf_predicate = absl::nullopt);

  // Tests whether the given list is a flat list of leaves.
  static bool AllLeaves(const pybind11::iterable& x);
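With the optional predicate added above, a caller can stop flattening at caller-defined leaves instead of recursing all the way down. A minimal sketch of calling the updated Flatten from C++ follows; the wrapper function is invented for illustration, and it assumes the pybind11 and pytree.h headers are included and that the GIL is held, as in the surrounding code.

// Sketch only: treats every Python dict as a leaf instead of recursing into it.
std::pair<std::vector<pybind11::object>, std::unique_ptr<xla::PyTreeDef>>
FlattenWithDictLeaves(pybind11::handle tree) {
  pybind11::function is_leaf = pybind11::cpp_function([](pybind11::handle h) {
    return pybind11::isinstance<pybind11::dict>(h);
  });
  // Passing absl::nullopt (or omitting the argument) restores the old behavior.
  return xla::PyTreeDef::Flatten(tree, is_leaf);
}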
@ -37,6 +37,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:computation_placer",
|
"//tensorflow/compiler/xla/service:computation_placer",
|
||||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||||
"//tensorflow/core/framework:allocator",
|
"//tensorflow/core/framework:allocator",
|
||||||
|
"//tensorflow/core/platform:casts",
|
||||||
"//tensorflow/core/platform:env",
|
"//tensorflow/core/platform:env",
|
||||||
"//tensorflow/core/profiler/lib:traceme",
|
"//tensorflow/core/profiler/lib:traceme",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
|
@ -37,8 +37,8 @@ namespace xla {
|
|||||||
|
|
||||||
TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
|
TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
|
||||||
int core_on_chip)
|
int core_on_chip)
|
||||||
: xla::PjRtDevice(id, /*local_device_state=*/nullptr,
|
: xla::PjRtStreamExecutorDevice(id, /*local_device_state=*/nullptr,
|
||||||
/*device_kind=*/"Cloud TPU", host_id),
|
/*device_kind=*/"Cloud TPU", host_id),
|
||||||
coords_(coords),
|
coords_(coords),
|
||||||
core_on_chip_(core_on_chip) {}
|
core_on_chip_(core_on_chip) {}
|
||||||
|
|
||||||
@ -531,7 +531,7 @@ PyTpuExecutable::PyTpuExecutable(
|
|||||||
<< "Inserting duplicate replica:" << replica;
|
<< "Inserting duplicate replica:" << replica;
|
||||||
executables_[replica] =
|
executables_[replica] =
|
||||||
client_->driver()->LoadProgram(device_id, compiled_program.get(), {});
|
client_->driver()->LoadProgram(device_id, compiled_program.get(), {});
|
||||||
addressable_device_logical_ids_.emplace_back(replica, partition);
|
local_logical_device_ids_.emplace_back(replica, partition);
|
||||||
local_devices_.push_back(device);
|
local_devices_.push_back(device);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -711,8 +711,8 @@ PyTpuExecutable::ExecuteOnLocalDevices(
|
|||||||
// long time and we want all cores to be scheduled in parallel.
|
// long time and we want all cores to be scheduled in parallel.
|
||||||
thread_pool->Schedule([this, i, argument_handles, &results, &results_lock,
|
thread_pool->Schedule([this, i, argument_handles, &results, &results_lock,
|
||||||
&execute_semaphore]() {
|
&execute_semaphore]() {
|
||||||
const int replica = addressable_device_logical_ids_[i].first;
|
const int replica = local_logical_device_ids_[i].first;
|
||||||
const int partition = addressable_device_logical_ids_[i].second;
|
const int partition = local_logical_device_ids_[i].second;
|
||||||
RunId run_id;
|
RunId run_id;
|
||||||
auto result = ExecuteHelper(argument_handles, argument_handles[i],
|
auto result = ExecuteHelper(argument_handles, argument_handles[i],
|
||||||
replica, partition, run_id);
|
replica, partition, run_id);
|
||||||
|
@ -32,13 +32,14 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/status.h"
|
#include "tensorflow/compiler/xla/status.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
#include "tensorflow/core/platform/threadpool.h"
|
#include "tensorflow/core/platform/threadpool.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
constexpr char kTpuPlatform[] = "tpu";
|
constexpr char kTpuPlatform[] = "tpu";
|
||||||
|
|
||||||
class TpuDevice : public PjRtDevice {
|
class TpuDevice : public PjRtStreamExecutorDevice {
|
||||||
public:
|
public:
|
||||||
TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
|
TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
|
||||||
int core_on_chip);
|
int core_on_chip);
|
||||||
@ -298,9 +299,8 @@ class PyTpuExecutable {
|
|||||||
return device_assignment_;
|
return device_assignment_;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<std::pair<int, int>>& addressable_device_logical_ids()
|
const std::vector<std::pair<int, int>>& local_logical_device_ids() const {
|
||||||
const {
|
return local_logical_device_ids_;
|
||||||
return addressable_device_logical_ids_;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() const {
|
const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() const {
|
||||||
@ -341,14 +341,16 @@ class PyTpuExecutable {
|
|||||||
|
|
||||||
// The replica and partition indices of device_assignment_ to be run by this
|
// The replica and partition indices of device_assignment_ to be run by this
|
||||||
// client. On single-host platforms without partitioning, this is all replicas
|
// client. On single-host platforms without partitioning, this is all replicas
|
||||||
// (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
|
// (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case
|
||||||
// case on multi-host platforms. If there are 4 replicas and 2 partitions on a
|
// on multi-host platforms.
|
||||||
// single host platform, size of addressable_device_logical_ids_ is 4*2 = 8.
|
// If there are 4 replicas and 2 partitions on a single host platform, size of
|
||||||
std::vector<std::pair<int, int>> addressable_device_logical_ids_;
|
// local_logical_device_ids_ is 4*2 = 8.
|
||||||
|
std::vector<std::pair<int, int>> local_logical_device_ids_;
|
||||||
|
|
||||||
// local_devices_[i] is the Device to which addressable_device_logical_ids_[i]
|
// local_devices_[i] is the Device to which local_logical_device_ids_[i] is
|
||||||
// is assigned. shared_ptrs instead of unique_ptrs to play well with the
|
// assigned.
|
||||||
// Python bindings (see xla.cc).
|
// shared_ptrs instead of unique_ptrs to play well with the Python bindings
|
||||||
|
// (see xla.cc).
|
||||||
std::vector<std::shared_ptr<PjRtDevice>> local_devices_;
|
std::vector<std::shared_ptr<PjRtDevice>> local_devices_;
|
||||||
|
|
||||||
xla::Shape result_shape_;
|
xla::Shape result_shape_;
|
||||||
|
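The comment block above pins down the contract between local_logical_device_ids_ and local_devices_: entry i of one names the (replica, partition) pair served by entry i of the other. A minimal standalone sketch of that indexing (illustrative only; FakeDevice and the 4x2 shape are assumptions, not code from this change):

    // Walks the two parallel vectors described above; with 4 replicas and
    // 2 partitions on one host there are 4*2 = 8 entries.
    #include <cstdio>
    #include <utility>
    #include <vector>

    struct FakeDevice { int id; };

    int main() {
      std::vector<std::pair<int, int>> local_logical_device_ids;
      std::vector<FakeDevice> local_devices;
      for (int replica = 0; replica < 4; ++replica) {
        for (int partition = 0; partition < 2; ++partition) {
          local_logical_device_ids.emplace_back(replica, partition);
          local_devices.push_back(FakeDevice{replica * 2 + partition});
        }
      }
      // local_devices[i] is the device assigned to local_logical_device_ids[i].
      for (size_t i = 0; i < local_logical_device_ids.size(); ++i) {
        std::printf("replica %d partition %d -> device %d\n",
                    local_logical_device_ids[i].first,
                    local_logical_device_ids[i].second, local_devices[i].id);
      }
      return 0;
    }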
@ -186,7 +186,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {

 py::class_<PyTpuExecutable>(m, "TpuExecutable")
 .def("local_logical_device_ids",
-&PyTpuExecutable::addressable_device_logical_ids)
+&PyTpuExecutable::local_logical_device_ids)
 .def("local_devices", &PyTpuExecutable::local_devices)
 .def_property_readonly("client", &PyTpuExecutable::client)
 .def("size_of_generated_code_in_bytes",

@ -16,8 +16,8 @@ limitations under the License.
 #include "tensorflow/compiler/xla/python/types.h"

 #include "absl/container/flat_hash_map.h"
-#include "tensorflow/compiler/xla/python/bfloat16.h"
 #include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/python/lib/core/bfloat16.h"

 namespace xla {

@ -81,8 +81,8 @@ xla::StatusOr<py::dtype> PrimitiveTypeToDtype(PrimitiveType type) {
 case U64:
 return py::dtype::of<uint64>();
 case BF16: {
-TF_ASSIGN_OR_RETURN(py::object bfloat16, Bfloat16Dtype());
-return py::dtype::from_args(bfloat16);
+py::handle bfloat16(tensorflow::Bfloat16Dtype());
+return py::dtype::from_args(py::reinterpret_borrow<py::object>(bfloat16));
 }
 case F16:
 return py::dtype("e"); // PEP 3118 code for "float16

@ -237,10 +237,11 @@ StatusOr<py::object> LiteralToPython(std::shared_ptr<xla::Literal> literal) {
 // We requested an array of uint16 since NumPy doesn't know how
 // to produce our custom bfloat16 type. Reinterpret the array as bfloat16
 // before handing it back to the caller.
-TF_ASSIGN_OR_RETURN(py::object bfloat16, Bfloat16Dtype());
+py::handle bfloat16(tensorflow::Bfloat16Dtype());
+bfloat16.inc_ref();
 array = py::reinterpret_steal<py::array>(
 PyArray_View(reinterpret_cast<PyArrayObject*>(array.ptr()),
-reinterpret_cast<PyArray_Descr*>(bfloat16.release().ptr()),
+reinterpret_cast<PyArray_Descr*>(bfloat16.ptr()),
 static_cast<PyTypeObject*>(nullptr)));
 }
 return array;
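The switch from an owning py::object to a py::handle plus an explicit inc_ref() matters because PyArray_View steals a reference to the descriptor it is given. A standalone sketch of that ownership rule (assuming a pybind11 build with embedded-interpreter support; math.pi is just a convenient stand-in object):

    // py::handle never adjusts the refcount by itself, so a reference must be
    // added explicitly before handing the pointer to an API that steals one.
    #include <pybind11/embed.h>
    #include <iostream>

    namespace py = pybind11;

    int main() {
      py::scoped_interpreter guard;

      py::object owned = py::module::import("math").attr("pi");  // owns a ref
      py::handle borrowed = owned;   // same PyObject*, no refcount change
      borrowed.inc_ref();            // now safe to give one reference away
      Py_ssize_t before = Py_REFCNT(borrowed.ptr());

      // A consumer that "steals" a reference would decref it when done; we
      // release the extra reference here to stand in for that consumer.
      borrowed.dec_ref();
      std::cout << "refcount before handing off: " << before << "\n";
      return 0;
    }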
@ -40,7 +40,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/pjrt/interpreter_device.h"
 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
 #include "tensorflow/compiler/xla/pjrt/tpu_client.h"
-#include "tensorflow/compiler/xla/python/bfloat16.h"
 #include "tensorflow/compiler/xla/python/dlpack.h"
 #include "tensorflow/compiler/xla/python/jax_jit.h"
 #include "tensorflow/compiler/xla/python/ops.h"

@ -59,6 +58,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/platform/errors.h"
+#include "tensorflow/python/lib/core/bfloat16.h"
 #include "tensorflow/stream_executor/platform.h"

 namespace xla {

@ -110,6 +110,8 @@ PYBIND11_MODULE(xla_extension, m) {
 throw std::runtime_error("Unable to initialize Numpy API");
 }

+CHECK(tensorflow::RegisterNumpyBfloat16());
+
 // Types
 py::enum_<PrimitiveType>(m, "PrimitiveType")
 .value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID)

@ -132,7 +134,8 @@ PYBIND11_MODULE(xla_extension, m) {
 .value("OPAQUE_TYPE", OPAQUE_TYPE)
 .value("TOKEN", TOKEN);

-m.def("bfloat16_dtype", Bfloat16Dtype);
+m.def("bfloat16_dtype",
+[]() { return py::handle(tensorflow::Bfloat16Dtype()); });

 // Must be before PyClient.compile.
 BuildXlaCompilerSubmodule(m);

@ -149,7 +152,10 @@ PYBIND11_MODULE(xla_extension, m) {
 .def_property_readonly("host_id", &PjRtDevice::host_id,
 "Integer ID of this device's host.\n\n"
 "This is always 0 except on multi-host platforms.")
-.def_property_readonly("platform", &PjRtDevice::platform_name)
+.def_property_readonly("platform",
+[](const PjRtDevice& device) {
+return device.client()->platform_name();
+})
 .def_property_readonly("device_kind", &PjRtDevice::device_kind)
 .def_property_readonly(
 "client",

@ -234,7 +240,7 @@ PYBIND11_MODULE(xla_extension, m) {
 py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
 py_local_client.def_property_readonly("platform", &PyClient::platform_name)
 .def("device_count", &PyClient::device_count)
-.def("local_device_count", &PyClient::local_device_count)
+.def("local_device_count", &PyClient::addressable_device_count)
 .def("devices", &PyClient::Devices)
 .def("local_devices", &PyClient::LocalDevices)
 .def("host_id", &PyClient::host_id)

@ -381,10 +387,10 @@ PYBIND11_MODULE(xla_extension, m) {
 [](PyExecutable* exec) {
 auto span = exec->addressable_device_logical_ids();
 // Not on dispatch critical path, so ok to have heap allocation.
-std::vector<std::pair<int, int>> addressable_device_logical_ids;
-addressable_device_logical_ids.reserve(span.size());
+std::vector<std::pair<int, int>> addressable_device_logic_ids;
+addressable_device_logic_ids.reserve(span.size());
 for (const auto& logical_device_id : span) {
-addressable_device_logical_ids.push_back(std::make_pair(
+addressable_device_logic_ids.push_back(std::make_pair(
 logical_device_id.replica, logical_device_id.partition));
 }
 })

@ -42,6 +42,7 @@ limitations under the License.
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 #include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/threadpool.h"

 namespace xla {

@ -158,6 +159,18 @@ class AotCompilationMetadata {
 // platform.
 class Compiler {
 public:
+struct CompileOptions {
+// If device_allocator is not null, the compiler may use it to allocate temp
+// space on the device for use during compilation. For example, the
+// compiler may allocate buffers on the device and then run variants of a
+// given algorithm over those buffers, to see which variant is fastest. Any
+// space allocated will be deallocated before the compilation returns.
+se::DeviceMemoryAllocator* device_allocator = nullptr;
+
+// An optional thread pool for parallel compilation.
+tensorflow::thread::ThreadPool* thread_pool = nullptr;
+};
+
 virtual ~Compiler() {}

 // Returns the ID of the platform that this compiler targets.

@ -165,31 +178,24 @@ class Compiler {

 // Runs Hlo passes to optimize the given Hlo module, returns the optimized
 // module.
-//
-// If device_allocator is not null, the compiler may use it to allocate temp
-// space on the device for use during compilation. For example, the compiler
-// may allocate buffers on the device and then run variants of a given
-// algorithm over those buffers, to see which variant is fastest. Any space
-// allocated should be deallocated before this function returns.
 virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
 std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
-se::DeviceMemoryAllocator* device_allocator) = 0;
+const CompileOptions& options) = 0;
+StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
+std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
+se::DeviceMemoryAllocator* device_allocator) {
+return RunHloPasses(std::move(module), executor,
+CompileOptions{device_allocator});
+}

 // Runs HLO passes to optimize the given HloModule, perform scheduling and
 // buffer assignment, returns the optimized module and the buffer assignments.
 // This interface is intentionally narrow.
-//
-// If device_allocator is not null, the compiler may use it to allocate temp
-// space on the device for use during compilation. For example, the compiler
-// may allocate buffers on the device and then run variants of a given
-// algorithm over those buffers, to see which variant is fastest. Any space
-// allocated should be deallocated before this function returns.
 virtual StatusOr<
 std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
 RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
-se::StreamExecutor* executor,
-se::DeviceMemoryAllocator* device_allocator,
-bool optimize) {
+se::StreamExecutor* executor, bool optimize,
+const CompileOptions& options) {
 return Unimplemented("This compiler does not support this method");
 }

@ -201,24 +207,33 @@ class Compiler {
 //
 // The compiler may optionally specialize to the individual device
 // (not just type of device) indicated by the executor.
-//
-// device_allocator is optional; see RunHloPasses.
 virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
 std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
-se::DeviceMemoryAllocator* device_allocator) = 0;
+const CompileOptions& options) = 0;
+StatusOr<std::unique_ptr<Executable>> RunBackend(
+std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
+se::DeviceMemoryAllocator* device_allocator) {
+return RunBackend(std::move(module), executor,
+CompileOptions{device_allocator});
+}

 // Compiles a set of HLO modules that can run in parallel, potentially
 // communicating data between the modules, and returns a corresponding
 // sequence of executable objects.
 //
-// device_allocator is optional; see RunHloPasses.
-//
 // TODO(b/68666782): Remove this method after adding support for multiple
 // modules to RunHloPasses and RunBackends.
 virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
 std::unique_ptr<HloModuleGroup> module_group,
 std::vector<std::vector<se::StreamExecutor*>> stream_exec,
-se::DeviceMemoryAllocator* device_allocator) = 0;
+const CompileOptions& options) = 0;
+StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
+std::unique_ptr<HloModuleGroup> module_group,
+std::vector<std::vector<se::StreamExecutor*>> stream_exec,
+se::DeviceMemoryAllocator* device_allocator) {
+return Compile(std::move(module_group), stream_exec,
+CompileOptions{device_allocator});
+}

 // Returns the backend configurations that the backend will consider for the
 // given HLO. Returns no configurations if the backend does not support
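The new Compiler::CompileOptions bundles the device allocator with an optional thread pool, and the deprecated-signature overloads forward through CompileOptions{device_allocator}, which leaves thread_pool at its nullptr default. A standalone sketch of that aggregate-initialization pattern (the types here are simplified stand-ins, not XLA classes):

    #include <cassert>

    struct DeviceMemoryAllocator {};  // stand-in for se::DeviceMemoryAllocator
    struct ThreadPool {};             // stand-in for tensorflow::thread::ThreadPool

    struct CompileOptions {
      DeviceMemoryAllocator* device_allocator = nullptr;
      ThreadPool* thread_pool = nullptr;
    };

    // New-style entry point that only looks at the options struct.
    bool RunHloPasses(const CompileOptions& options) {
      return options.thread_pool != nullptr;  // "parallel compilation enabled"
    }

    // Old-style shim, mirroring the compatibility overloads in the diff.
    bool RunHloPasses(DeviceMemoryAllocator* device_allocator) {
      return RunHloPasses(CompileOptions{device_allocator});
    }

    int main() {
      DeviceMemoryAllocator allocator;
      ThreadPool pool;
      assert(!RunHloPasses(&allocator));                        // no thread pool
      assert(RunHloPasses(CompileOptions{&allocator, &pool}));  // with thread pool
      return 0;
    }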
@ -553,7 +553,7 @@ Status CreateHloProfilingArtifacts(

 StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
 std::unique_ptr<HloModule> module, se::StreamExecutor* /*stream_exec*/,
-se::DeviceMemoryAllocator* /*device_allocator*/) {
+const CompileOptions& /*options*/) {
 std::unique_ptr<llvm::TargetMachine> jit_target_machine =
 SimpleOrcJIT::InferTargetMachineForJIT(
 CompilerTargetOptions(module->config()),

@ -566,12 +566,13 @@ StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(

 StatusOr<
 std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
-CpuCompiler::RunHloPassesAndBufferAssignement(
-std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
-se::DeviceMemoryAllocator* device_allocator, bool optimize) {
+CpuCompiler::RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
+se::StreamExecutor* executor,
+bool optimize,
+const CompileOptions& options) {
 if (optimize) {
-TF_ASSIGN_OR_RETURN(
-module, RunHloPasses(std::move(module), executor, device_allocator));
+TF_ASSIGN_OR_RETURN(module,
+RunHloPasses(std::move(module), executor, options));
 }

 // Select an order for emitting the HLO instructions for each computation.

@ -632,7 +633,7 @@ struct OrcJITPostCompilationHook {

 StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
 std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-se::DeviceMemoryAllocator* /*device_allocator*/) {
+const CompileOptions& options) {
 VLOG(1) << "Compiling: " << module->name();
 XLA_SCOPED_LOGGING_TIMER(
 absl::StrFormat("Compiling [%s] for CPU using JIT", module->name()));

@ -134,18 +134,17 @@ class CpuCompiler : public LLVMCompiler {

 StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
 std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-se::DeviceMemoryAllocator* device_allocator) override;
+const CompileOptions& options) override;

 StatusOr<
 std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
 RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
-se::StreamExecutor* executor,
-se::DeviceMemoryAllocator* device_allocator,
-bool optimize) override;
+se::StreamExecutor* executor, bool optimize,
+const CompileOptions& options) override;

 StatusOr<std::unique_ptr<Executable>> RunBackend(
 std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-se::DeviceMemoryAllocator* device_allocator) override;
+const CompileOptions& options) override;

 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
 CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
@ -440,7 +440,7 @@ filegroup(
 name = "nccl_collective_thunk_src",
 srcs = if_nccl(
 ["nccl_collective_thunk.cc"],
-["dummy_collective_thunk.cc"],
+["nccl_collective_thunk_dummy.cc"],
 ),
 )

@ -448,7 +448,7 @@ tf_cuda_library(
 name = "nccl_collective_thunk",
 srcs = if_cuda_or_rocm(
 [":nccl_collective_thunk_src"],
-["dummy_collective_thunk.cc"],
+["nccl_collective_thunk_dummy.cc"],
 ),
 hdrs = ["nccl_collective_thunk.h"],
 deps = [

@ -480,7 +480,7 @@ filegroup(
 name = "nccl_all_gather_thunk_src",
 srcs = if_nccl(
 ["nccl_all_gather_thunk.cc"],
-["dummy_all_gather_thunk.cc"],
+["nccl_all_gather_thunk_dummy.cc"],
 ),
 )

@ -488,7 +488,7 @@ tf_cuda_library(
 name = "nccl_all_gather_thunk",
 srcs = if_cuda_or_rocm(
 [":nccl_all_gather_thunk_src"],
-["dummy_all_gather_thunk.cc"],
+["nccl_all_gather_thunk_dummy.cc"],
 ),
 hdrs = ["nccl_all_gather_thunk.h"],
 deps = [

@ -520,7 +520,7 @@ filegroup(
 name = "nccl_all_reduce_thunk_src",
 srcs = if_nccl(
 ["nccl_all_reduce_thunk.cc"],
-["dummy_all_reduce_thunk.cc"],
+["nccl_all_reduce_thunk_dummy.cc"],
 ),
 )

@ -528,7 +528,7 @@ tf_cuda_library(
 name = "nccl_all_reduce_thunk",
 srcs = if_cuda_or_rocm(
 [":nccl_all_reduce_thunk_src"],
-["dummy_all_reduce_thunk.cc"],
+["nccl_all_reduce_thunk_dummy.cc"],
 ),
 hdrs = ["nccl_all_reduce_thunk.h"],
 deps = [

@ -560,7 +560,7 @@ filegroup(
 name = "nccl_all_to_all_thunk_src",
 srcs = if_nccl(
 ["nccl_all_to_all_thunk.cc"],
-["dummy_all_to_all_thunk.cc"],
+["nccl_all_to_all_thunk_dummy.cc"],
 ),
 )

@ -568,7 +568,7 @@ tf_cuda_library(
 name = "nccl_all_to_all_thunk",
 srcs = if_cuda_or_rocm(
 [":nccl_all_to_all_thunk_src"],
-["dummy_all_to_all_thunk.cc"],
+["nccl_all_to_all_thunk_dummy.cc"],
 ),
 hdrs = ["nccl_all_to_all_thunk.h"],
 deps = [

@ -600,7 +600,7 @@ filegroup(
 name = "nccl_test_utils_src",
 srcs = if_nccl(
 ["nccl_test_utils.cc"],
-["dummy_nccl_test_utils.cc"],
+["nccl_test_utils_dummy.cc"],
 ),
 )

@ -608,7 +608,7 @@ tf_cuda_library(
 name = "nccl_test_utils",
 srcs = if_cuda_or_rocm(
 [":nccl_test_utils_src"],
-["dummy_nccl_test_utils.cc"],
+["nccl_test_utils_dummy.cc"],
 ),
 hdrs = ["nccl_test_utils.h"],
 deps = [
@ -1452,7 +1452,11 @@ cc_library(
 "//tensorflow/stream_executor:stream_executor_headers",
 "@com_google_absl//absl/memory",
 "@com_google_absl//absl/strings",
+"@llvm-project//llvm:AsmParser",
+"@llvm-project//llvm:BitReader",
+"@llvm-project//llvm:BitWriter",
 "@llvm-project//llvm:Core",
+"@llvm-project//llvm:TransformUtils",
 "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
 "@llvm-project//mlir:IR",
 ],

@ -1517,7 +1521,7 @@ cc_library(
 "//tensorflow/stream_executor:stream_executor_headers",
 "//tensorflow/stream_executor/cuda:cuda_diagnostics",
 "//tensorflow/stream_executor/gpu:asm_compiler",
-]),
+]) + ["//tensorflow/stream_executor/gpu:gpu_driver_header"],
 )

 cc_library(
@ -108,12 +108,17 @@ StatusOr<std::pair<std::string, std::vector<uint8>>>
 AMDGPUCompiler::CompileTargetBinary(const HloModule* module,
 llvm::Module* llvm_module,
 GpuVersion gpu_version,
-se::StreamExecutor* stream_exec) {
+se::StreamExecutor* stream_exec,
+bool relocatable) {
 if (rocdl_dir_.empty()) {
 // Compute rocdl_dir_ just once and cache it in this member.
 rocdl_dir_ = GetROCDLDir(module->config());
 }

+if (relocatable) {
+return Unimplemented("relocatable target binary is not implemented");
+}
+
 std::vector<uint8> hsaco;
 {
 XLA_SCOPED_LOGGING_TIMER(

@ -41,7 +41,8 @@ class AMDGPUCompiler : public GpuCompiler {

 StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
 const HloModule* hlo_module, llvm::Module* llvm_module,
-GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
+GpuVersion gpu_version, se::StreamExecutor* stream_exec,
+bool relocatable) override;

 private:
 // The parent directory of ROCm-Device-Libs IR libraries.
@ -24,11 +24,15 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "absl/strings/numbers.h"
 #include "absl/strings/str_cat.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/Bitcode/BitcodeReader.h"
+#include "llvm/Bitcode/BitcodeWriter.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
+#include "llvm/Transforms/Utils/SplitModule.h"
 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
 #include "mlir/InitAllDialects.h" // from @llvm-project
 #include "tensorflow/compiler/xla/protobuf_util.h"

@ -114,11 +118,13 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/blocking_counter.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/regexp.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 #include "tensorflow/core/platform/subprocess.h"
+#include "tensorflow/core/platform/threadpool.h"
 #include "tensorflow/core/platform/tracing.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
 #include "tensorflow/core/util/env_var.h"

@ -470,14 +476,14 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment(

 StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
 std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-se::DeviceMemoryAllocator* device_allocator) {
+const CompileOptions& options) {
 // We dump the post-optimization HLO in RunBackend so no need to dump it here.
 XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses");
 tensorflow::profiler::TraceMe activity(
 [&] { return absl::StrCat("HLO Transforms:", module->name()); },
 tensorflow::profiler::TraceMeLevel::kInfo);
 TF_RETURN_IF_ERROR(
-OptimizeHloModule(module.get(), stream_exec, device_allocator));
+OptimizeHloModule(module.get(), stream_exec, options.device_allocator));

 TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get()));

@ -494,10 +500,10 @@ StatusOr<
 std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
 GpuCompiler::RunHloPassesAndBufferAssignement(
 std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* executor,
-se::DeviceMemoryAllocator* device_allocator, bool optimize) {
+bool optimize, const CompileOptions& options) {
 if (optimize) {
-TF_ASSIGN_OR_RETURN(hlo_module, RunHloPasses(std::move(hlo_module),
-executor, device_allocator));
+TF_ASSIGN_OR_RETURN(hlo_module,
+RunHloPasses(std::move(hlo_module), executor, options));
 }

 std::unique_ptr<StreamAssignment> stream_assignment =
@ -641,24 +647,133 @@ static Status CompileModuleToLlvmIrImpl(
 return Status::OK();
 }

+StatusOr<std::pair<std::string, std::vector<uint8>>>
+GpuCompiler::CompileToTargetBinary(const HloModule& module,
+std::unique_ptr<llvm::Module> llvm_module,
+se::StreamExecutor* stream_exec,
+const CompileOptions& options) {
+using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
+
+const auto compile_single_module =
+[this, stream_exec, &module](
+llvm::Module* llvm_module,
+bool relocatable) -> StatusOr<BackendCompileResult> {
+{
+XLA_SCOPED_LOGGING_TIMER(
+"GpuCompiler::RunBackend - Running LLVM verifier");
+
+std::string err;
+llvm::raw_string_ostream err_stream(err);
+
+// verifyModule() returns true if the module is broken.
+TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream))
+<< "Invalid LLVM IR before optimizations:\n"
+<< err_stream.str()
+<< "\nThis probably indicates a bug in the HLO -> LLVM IR "
+"lowering. "
+"Rerun with --xla_dump_to to get the IR and looks for files "
+"with "
+"name containing: *"
+<< FilenameFor(module, "", "") << "*";
+}
+GpuVersion gpu_version = GetGpuVersion(stream_exec);
+return CompileTargetBinary(&module, llvm_module, gpu_version, stream_exec,
+relocatable);
+};
+
+tensorflow::thread::ThreadPool* thread_pool = options.thread_pool;
+if (!thread_pool) {
+return compile_single_module(llvm_module.get(), /*relocatable=*/false);
+}
+
+// Test whether LinkModules is supported.
+if (this->LinkModules(stream_exec, {}).status().code() ==
+tensorflow::error::Code::UNIMPLEMENTED) {
+return compile_single_module(llvm_module.get(), /*relocatable=*/false);
+}
+
+std::vector<std::unique_ptr<llvm::Module>> llvm_modules;
+int num_functions = 0;
+for (llvm::Function& func : llvm_module->functions()) {
+if (!func.isDeclaration() &&
+func.getLinkage() == llvm::GlobalValue::LinkageTypes::ExternalLinkage) {
+num_functions++;
+}
+}
+
+llvm::SplitModule(
+std::move(llvm_module),
+std::max<unsigned>(
+1, std::min<unsigned>(thread_pool->NumThreads(), num_functions)),
+[&](std::unique_ptr<llvm::Module> module) {
+llvm_modules.push_back(std::move(module));
+},
+/*PreserveLocals=*/true);
+
+std::vector<StatusOr<BackendCompileResult>> compile_results(
+llvm_modules.size());
+tensorflow::BlockingCounter counter(llvm_modules.size());
+for (int i = 0; i < llvm_modules.size(); i++) {
+thread_pool->Schedule([&compile_results, compile_single_module, i,
+&llvm_modules, &counter] {
+llvm::Module* original_module = llvm_modules[i].get();
+llvm::LLVMContext context;
+std::string buffer;
+llvm::raw_string_ostream error(buffer);
+llvm::DiagnosticPrinterRawOStream printer(error);
+auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info,
+void* Context) {
+auto printer = static_cast<llvm::DiagnosticPrinterRawOStream*>(Context);
+diag_info.print(*printer);
+};
+context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer);
+
+std::unique_ptr<llvm::Module> new_llvm_module;
+{
+std::string ir;
+{
+llvm::raw_string_ostream os(ir);
+original_module->print(os, nullptr);
+}
+llvm::SMDiagnostic err;
+new_llvm_module = llvm::parseAssemblyString(ir, err, context);
+}
+
+compile_results[i] =
+compile_single_module(new_llvm_module.get(), /*relocatable=*/true);
+counter.DecrementCount();
+});
+}
+counter.Wait();
+
+std::string ptx_snippets;
+std::vector<std::vector<uint8>> submodule_compile_results;
+for (auto& maybe_result : compile_results) {
+TF_ASSIGN_OR_RETURN(auto result, maybe_result);
+if (result.second.empty()) {
+continue;
+}
+ptx_snippets += result.first;
+ptx_snippets += "\n";
+submodule_compile_results.push_back(result.second);
+}
+
+TF_ASSIGN_OR_RETURN(
+std::vector<uint8> backend_result,
+this->LinkModules(stream_exec, std::move(submodule_compile_results)));
+
+return std::make_pair(ptx_snippets, backend_result);
+}
+
 StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
 std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-se::DeviceMemoryAllocator* device_allocator) {
+const CompileOptions& options) {
 XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend");
 auto slow_compile_alarm = SlowCompilationAlarm();

 TF_RET_CHECK(stream_exec != nullptr);

 llvm::LLVMContext llvm_context;
-std::string buffer;
-llvm::raw_string_ostream error(buffer);
-llvm::DiagnosticPrinterRawOStream printer(error);
-auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info,
-void* Context) {
-auto printer = static_cast<llvm::DiagnosticPrinterRawOStream*>(Context);
-diag_info.print(*printer);
-};
-llvm_context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer);

 GpuDeviceInfo gpu_device_info;
 gpu_device_info.threads_per_block_limit =

@ -724,34 +839,16 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(

 llvm_ir::DumpIrIfEnabled(*module, *llvm_module, /*optimized=*/false);

-{
-XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier");
-
-std::string err;
-llvm::raw_string_ostream err_stream(err);
-
-// verifyModule() returns true if the module is broken.
-TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream))
-<< "Invalid LLVM IR before optimizations:\n"
-<< err_stream.str()
-<< "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
-"Rerun with --xla_dump_to to get the IR and looks for files with "
-"name containing: *"
-<< FilenameFor(*module, "", "") << "*";
-}
-
-GpuVersion gpu_version = GetGpuVersion(stream_exec);
-
 using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
 TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
-CompileTargetBinary(module.get(), llvm_module.get(),
-gpu_version, stream_exec));
+CompileToTargetBinary(*module, std::move(llvm_module),
+stream_exec, options));

 if (DumpingEnabledForHloModule(*module)) {
 DumpToFileInDirOrStdout(*module, "", "thunk_schedule",
 thunk_schedule->ToString());
 }

+GpuVersion gpu_version = GetGpuVersion(stream_exec);
 auto* gpu_executable = new GpuExecutable(
 backend_result.first, backend_result.second, gpu_version,
 std::move(thunk_schedule), std::move(module),
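CompileToTargetBinary above splits the LLVM module, schedules one compile per submodule on options.thread_pool, and waits on a tensorflow::BlockingCounter before linking the results. A standalone sketch of that fan-out/fan-in shape (std::thread and a hand-rolled counter stand in for the TensorFlow thread pool and BlockingCounter):

    #include <condition_variable>
    #include <cstdio>
    #include <mutex>
    #include <thread>
    #include <vector>

    class BlockingCounter {
     public:
      explicit BlockingCounter(int n) : count_(n) {}
      void DecrementCount() {
        std::lock_guard<std::mutex> lock(mu_);
        if (--count_ == 0) cv_.notify_all();
      }
      void Wait() {
        std::unique_lock<std::mutex> lock(mu_);
        cv_.wait(lock, [this] { return count_ == 0; });
      }
     private:
      std::mutex mu_;
      std::condition_variable cv_;
      int count_;
    };

    int main() {
      const int kNumChunks = 4;
      std::vector<int> results(kNumChunks);
      BlockingCounter counter(kNumChunks);
      std::vector<std::thread> workers;
      for (int i = 0; i < kNumChunks; ++i) {
        workers.emplace_back([&results, &counter, i] {
          results[i] = i * i;      // stand-in for compiling one submodule
          counter.DecrementCount();
        });
      }
      counter.Wait();              // all chunks done; safe to link the results
      for (std::thread& t : workers) t.join();
      for (int r : results) std::printf("%d ", r);
      std::printf("\n");
      return 0;
    }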
@ -53,14 +53,13 @@ class GpuCompiler : public LLVMCompiler {

 StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
 std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-se::DeviceMemoryAllocator* device_allocator) override;
+const CompileOptions& options) override;

 StatusOr<
 std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
 RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> hlo_module,
-se::StreamExecutor* executor,
-se::DeviceMemoryAllocator* device_allocator,
-bool optimize) override;
+se::StreamExecutor* executor, bool optimize,
+const CompileOptions& options) override;

 Status OptimizeHloModule(HloModule* hlo_module,
 se::StreamExecutor* stream_exec,

@ -84,19 +83,23 @@ class GpuCompiler : public LLVMCompiler {

 virtual StatusOr<std::pair<std::string, std::vector<uint8>>>
 CompileTargetBinary(const HloModule* hlo_module, llvm::Module* llvm_module,
-GpuVersion gpu_version,
-se::StreamExecutor* stream_exec) = 0;
+GpuVersion gpu_version, se::StreamExecutor* stream_exec,
+bool relocatable) = 0;

 Status PrepareHloModuleForIrEmitting(HloModule* hlo_module);

 StatusOr<std::unique_ptr<Executable>> RunBackend(
 std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-se::DeviceMemoryAllocator* device_allocator) override;
+const CompileOptions& options) override;

 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
 CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
 AotCompilationOptions const& options) override;

+StatusOr<std::pair<std::string, std::vector<uint8>>> CompileToTargetBinary(
+const HloModule& module, std::unique_ptr<llvm::Module> llvm_module,
+se::StreamExecutor* stream_exec, const CompileOptions& options);
+
 se::Platform::Id PlatformId() const override { return platform_id_; }

 HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {

@ -116,6 +119,12 @@ class GpuCompiler : public LLVMCompiler {
 }

 private:
+virtual StatusOr<std::vector<uint8>> LinkModules(
+se::StreamExecutor* stream_exec,
+std::vector<std::vector<uint8>> modules) {
+return Unimplemented("LinkModules is not implemented.");
+}
+
 se::Platform::Id platform_id_;

 // The triple that represents our target.
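The base-class LinkModules default returns Unimplemented, and CompileToTargetBinary probes it with an empty module list to decide whether parallel compilation plus linking is available for the current backend. A standalone sketch of that probe-by-status-code pattern (the Status and Backend types are simplified stand-ins, not XLA classes):

    #include <iostream>
    #include <string>
    #include <vector>

    enum class Code { kOk, kUnimplemented };
    struct Status {
      Code code;
      std::string message;
    };

    class Backend {
     public:
      virtual ~Backend() = default;
      virtual Status LinkModules(const std::vector<std::string>& modules) {
        return {Code::kUnimplemented, "LinkModules is not implemented."};
      }
    };

    class LinkingBackend : public Backend {
     public:
      Status LinkModules(const std::vector<std::string>& modules) override {
        return {Code::kOk, ""};  // pretend the modules were linked
      }
    };

    bool SupportsParallelCompile(Backend* backend) {
      // Probe with an empty module list, mirroring LinkModules(stream_exec, {}).
      return backend->LinkModules({}).code != Code::kUnimplemented;
    }

    int main() {
      Backend plain;
      LinkingBackend linking;
      std::cout << SupportsParallelCompile(&plain) << " "
                << SupportsParallelCompile(&linking) << "\n";  // prints "0 1"
      return 0;
    }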
@ -51,6 +51,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||||
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
|
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
|
||||||
#include "tensorflow/stream_executor/gpu/asm_compiler.h"
|
#include "tensorflow/stream_executor/gpu/asm_compiler.h"
|
||||||
|
#include "tensorflow/stream_executor/gpu/gpu_driver.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
@ -299,7 +300,8 @@ StatusOr<std::pair<std::string, std::vector<uint8>>>
|
|||||||
NVPTXCompiler::CompileTargetBinary(const HloModule* module,
|
NVPTXCompiler::CompileTargetBinary(const HloModule* module,
|
||||||
llvm::Module* llvm_module,
|
llvm::Module* llvm_module,
|
||||||
GpuVersion gpu_version,
|
GpuVersion gpu_version,
|
||||||
se::StreamExecutor* stream_exec) {
|
se::StreamExecutor* stream_exec,
|
||||||
|
bool relocatable) {
|
||||||
std::pair<int, int> compute_capability =
|
std::pair<int, int> compute_capability =
|
||||||
absl::get<std::pair<int, int>>(gpu_version);
|
absl::get<std::pair<int, int>>(gpu_version);
|
||||||
|
|
||||||
@ -338,7 +340,7 @@ NVPTXCompiler::CompileTargetBinary(const HloModule* module,
|
|||||||
|
|
||||||
std::vector<uint8> cubin = CompileGpuAsmOrGetCachedResult(
|
std::vector<uint8> cubin = CompileGpuAsmOrGetCachedResult(
|
||||||
stream_exec, ptx, compute_capability.first, compute_capability.second,
|
stream_exec, ptx, compute_capability.first, compute_capability.second,
|
||||||
module->config());
|
module->config(), relocatable);
|
||||||
|
|
||||||
return std::pair<std::string, std::vector<uint8>>(std::move(ptx),
|
return std::pair<std::string, std::vector<uint8>>(std::move(ptx),
|
||||||
std::move(cubin));
|
std::move(cubin));
|
||||||
@ -346,7 +348,7 @@ NVPTXCompiler::CompileTargetBinary(const HloModule* module,
|
|||||||
|
|
||||||
std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
|
std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
|
||||||
se::StreamExecutor* stream_exec, const string& ptx, int cc_major,
|
se::StreamExecutor* stream_exec, const string& ptx, int cc_major,
|
||||||
int cc_minor, const HloModuleConfig& hlo_module_config) {
|
int cc_minor, const HloModuleConfig& hlo_module_config, bool relocatable) {
|
||||||
XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompileGpuAsmOrGetCachedResult");
|
XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompileGpuAsmOrGetCachedResult");
|
||||||
tensorflow::profiler::TraceMe activity(
|
tensorflow::profiler::TraceMe activity(
|
||||||
"PTX->CUBIN", tensorflow::profiler::TraceMeLevel::kInfo);
|
"PTX->CUBIN", tensorflow::profiler::TraceMeLevel::kInfo);
|
||||||
@ -361,7 +363,7 @@ std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
|
|||||||
tensorflow::mutex_lock lock(mutex_);
|
tensorflow::mutex_lock lock(mutex_);
|
||||||
std::tie(iter, inserted) = compilation_cache_.emplace(
|
std::tie(iter, inserted) = compilation_cache_.emplace(
|
||||||
std::piecewise_construct,
|
std::piecewise_construct,
|
||||||
std::forward_as_tuple(ptx, cc_major, cc_minor),
|
std::forward_as_tuple(ptx, cc_major, cc_minor, relocatable),
|
||||||
std::forward_as_tuple());
|
std::forward_as_tuple());
|
||||||
cache_ptx = &iter->first.ptx;
|
cache_ptx = &iter->first.ptx;
|
||||||
cache_value = &iter->second;
|
cache_value = &iter->second;
|
||||||
@ -375,9 +377,13 @@ std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
|
|||||||
if (inserted) {
|
if (inserted) {
|
||||||
CHECK(!cache_value->compilation_done);
|
CHECK(!cache_value->compilation_done);
|
||||||
if (!ptx.empty()) {
|
if (!ptx.empty()) {
|
||||||
StatusOr<std::vector<uint8>> maybe_cubin =
|
auto ptxas_config = PtxOptsFromConfig(hlo_module_config);
|
||||||
se::CompileGpuAsm(stream_exec->device_ordinal(), cache_ptx->c_str(),
|
if (relocatable) {
|
||||||
PtxOptsFromConfig(hlo_module_config));
|
ptxas_config.extra_flags.push_back("-c");
|
||||||
|
}
|
||||||
|
StatusOr<std::vector<uint8>> maybe_cubin = se::CompileGpuAsm(
|
||||||
|
stream_exec->device_ordinal(), cache_ptx->c_str(), ptxas_config);
|
||||||
|
|
||||||
if (maybe_cubin.ok()) {
|
if (maybe_cubin.ok()) {
|
||||||
cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie();
|
cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie();
|
||||||
VLOG(2) << "Compiled PTX size:" << ptx.size()
|
VLOG(2) << "Compiled PTX size:" << ptx.size()
|
||||||
@ -445,5 +451,17 @@ std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
|
|||||||
return cache_value->cubin_data;
|
return cache_value->cubin_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<std::vector<uint8>> NVPTXCompiler::LinkModules(
|
||||||
|
se::StreamExecutor* stream_exec, std::vector<std::vector<uint8>> modules) {
|
||||||
|
std::vector<stream_executor::CubinOrPTXImage> images;
|
||||||
|
images.reserve(modules.size());
|
||||||
|
for (auto& module : modules) {
|
||||||
|
images.push_back({"", std::move(module)});
|
||||||
|
}
|
||||||
|
return LinkGpuAsm(static_cast<se::gpu::GpuContext*>(
|
||||||
|
stream_exec->implementation()->GpuContextHack()),
|
||||||
|
images);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
@ -52,9 +52,14 @@ class NVPTXCompiler : public GpuCompiler {

 StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
 const HloModule* hlo_module, llvm::Module* llvm_module,
-GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
+GpuVersion gpu_version, se::StreamExecutor* stream_exec,
+bool relocatable) override;

 private:
+StatusOr<std::vector<uint8>> LinkModules(
+se::StreamExecutor* stream_exec,
+std::vector<std::vector<uint8>> modules) override;
+
 tensorflow::mutex mutex_;

 // When compiling an HLO module, we need to find a path to the nvvm libdevice

@ -71,7 +76,7 @@ class NVPTXCompiler : public GpuCompiler {
 // compiled cubin. If compilation was unsuccessful, returns an empty vector.
 std::vector<uint8> CompileGpuAsmOrGetCachedResult(
 se::StreamExecutor* stream_exec, const string& ptx, int cc_major,
-int cc_minor, const HloModuleConfig& hlo_module_config);
+int cc_minor, const HloModuleConfig& hlo_module_config, bool relocatable);

 // The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor}
 // -> cubin so we don't recompile the same ptx twice. This is important for

@ -86,24 +91,32 @@ class NVPTXCompiler : public GpuCompiler {
 // If compiling the ptx fails, we return an empty cubin, cross our fingers,
 // and leave compilation up to the driver.
 struct CompilationCacheKey {
-CompilationCacheKey(std::string ptx, int cc_major, int cc_minor)
-: ptx(std::move(ptx)), cc_major(cc_major), cc_minor(cc_minor) {}
+CompilationCacheKey(std::string ptx, int cc_major, int cc_minor,
+bool relocatable)
+: ptx(std::move(ptx)),
+cc_major(cc_major),
+cc_minor(cc_minor),
+relocatable(relocatable) {}
 string ptx;
 int cc_major;
 int cc_minor;
+bool relocatable;
 };
 struct CompilationCacheHash {
 size_t operator()(const CompilationCacheKey& key) const {
 return tensorflow::Hash64Combine(
-tensorflow::Hash64Combine(tensorflow::Hash64(key.ptx), key.cc_major),
-key.cc_minor);
+tensorflow::Hash64Combine(
+tensorflow::Hash64Combine(tensorflow::Hash64(key.ptx),
+key.cc_major),
+key.cc_minor),
+key.relocatable);
 }
 };
 struct CompilationCacheEq {
 size_t operator()(const CompilationCacheKey& a,
 const CompilationCacheKey& b) const {
 return a.cc_major == b.cc_major && a.cc_minor == b.cc_minor &&
-a.ptx == b.ptx;
+a.ptx == b.ptx && a.relocatable == b.relocatable;
 }
 };
 struct CompilationCacheValue {
|
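Note on the cache-key change above: the relocatable flag becomes part of the key, its hash, and its equality test, so relocatable and non-relocatable compilations of the same PTX no longer collide in compilation_cache_. The following self-contained sketch shows the same pattern using only standard-library hashing; CacheKey, CacheKeyHash, and CacheKeyEq are illustrative stand-ins, not XLA types.

// Standalone illustration of extending a cache key with a bool flag.
// Hypothetical names; only the pattern mirrors the diff above.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct CacheKey {
  std::string ptx;
  int cc_major;
  int cc_minor;
  bool relocatable;  // New field: keeps relocatable cubins separate.
};

struct CacheKeyHash {
  size_t operator()(const CacheKey& k) const {
    auto combine = [](size_t a, size_t b) {
      return a ^ (b + 0x9e3779b97f4a7c15ULL + (a << 6) + (a >> 2));
    };
    size_t h = std::hash<std::string>()(k.ptx);
    h = combine(h, std::hash<int>()(k.cc_major));
    h = combine(h, std::hash<int>()(k.cc_minor));
    h = combine(h, std::hash<bool>()(k.relocatable));  // Mirrors key.relocatable above.
    return h;
  }
};

struct CacheKeyEq {
  bool operator()(const CacheKey& a, const CacheKey& b) const {
    return a.cc_major == b.cc_major && a.cc_minor == b.cc_minor &&
           a.ptx == b.ptx && a.relocatable == b.relocatable;
  }
};

int main() {
  std::unordered_map<CacheKey, std::vector<uint8_t>, CacheKeyHash, CacheKeyEq> cache;
  cache[{"ptx-a", 7, 0, /*relocatable=*/false}] = {1, 2, 3};
  cache[{"ptx-a", 7, 0, /*relocatable=*/true}] = {4, 5, 6};
  std::cout << "entries: " << cache.size() << "\n";  // Prints 2: the keys differ.
}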
@@ -95,7 +95,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {

 StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
     std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* /*stream_exec*/,
-    se::DeviceMemoryAllocator* /*device_allocator*/) {
+    const CompileOptions& /*options*/) {
   VLOG(1) << "Run hlo passes on graph " << hlo_module->name();
   TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get()));
   return std::move(hlo_module);
@@ -103,7 +103,7 @@ StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(

 StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
     std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
-    se::DeviceMemoryAllocator* /*device_allocator*/) {
+    const CompileOptions& /*options*/) {
   TF_RET_CHECK(stream_exec != nullptr);

   VLOG(1) << "Run backend " << hlo_module->name();
@@ -128,7 +128,7 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
 StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
     std::unique_ptr<HloModuleGroup> module_group,
     std::vector<std::vector<se::StreamExecutor*>> stream_exec,
-    se::DeviceMemoryAllocator* device_allocator) {
+    const CompileOptions& options) {
   if (module_group->empty()) {
     return std::vector<std::unique_ptr<Executable>>();
   }
@@ -141,12 +141,10 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
         "Unexpected number of StreamExecutor's.");
   }
   auto hlo_modules = module_group->ConsumeModules();
-  TF_ASSIGN_OR_RETURN(auto module,
-                      RunHloPasses(std::move(hlo_modules[0]), stream_exec[0][0],
-                                   device_allocator));
-  TF_ASSIGN_OR_RETURN(
-      auto executable,
-      RunBackend(std::move(module), stream_exec[0][0], device_allocator));
+  TF_ASSIGN_OR_RETURN(auto module, RunHloPasses(std::move(hlo_modules[0]),
+                                                stream_exec[0][0], options));
+  TF_ASSIGN_OR_RETURN(auto executable, RunBackend(std::move(module),
+                                                  stream_exec[0][0], options));
   std::vector<std::unique_ptr<Executable>> ret;
   ret.push_back(std::move(executable));
   return std::move(ret);
@@ -45,14 +45,14 @@ class InterpreterCompiler : public Compiler {

   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
       std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) override;
+      const CompileOptions& options) override;
   StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) override;
+      const CompileOptions& options) override;
   StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
       std::unique_ptr<HloModuleGroup> module_group,
       std::vector<std::vector<se::StreamExecutor*>> stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) override;
+      const CompileOptions& options) override;

   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
   CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
@@ -24,7 +24,7 @@ namespace xla {
 StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(
     std::unique_ptr<HloModuleGroup> module_group,
     std::vector<std::vector<se::StreamExecutor*>> stream_execs,
-    se::DeviceMemoryAllocator* device_allocator) {
+    const CompileOptions& options) {
   // Tensorflow tries to enable the following behaviors in all its threads:
   //
   // - Denormals are zero (DAZ): roughly, operations treat denormal floats as
@@ -48,10 +48,10 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(

     TF_ASSIGN_OR_RETURN(modules[i],
                         RunHloPasses(std::move(modules[i]), stream_execs[i][0],
-                                     device_allocator));
+                                     options.device_allocator));
     TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                         RunBackend(std::move(modules[i]), stream_execs[i][0],
-                                   device_allocator));
+                                   options.device_allocator));
     result.push_back(std::move(executable));
   }

@@ -66,13 +66,14 @@ class LLVMCompiler : public Compiler {
   //     std::unique_ptr<HloModule> module,
   //     se::StreamExecutor* stream_exec,
   //     se::DeviceMemoryAllocator* device_allocator)
+  using Compiler::Compile;
   using Compiler::RunBackend;
   using Compiler::RunHloPasses;

   StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
       std::unique_ptr<HloModuleGroup> module_group,
       std::vector<std::vector<se::StreamExecutor*>> stream_execs,
-      se::DeviceMemoryAllocator* device_allocator) override;
+      const CompileOptions& options) override;

  protected:
   ModuleHook user_pre_optimization_hook_;
@@ -190,11 +190,12 @@ LocalService::CompileExecutables(
   // single partition computations are built using `BuildExecutables`, fix it,
   // and remove this special case (provided the performance if similar).
   if (build_options.num_partitions() == 1) {
-    TF_ASSIGN_OR_RETURN(
-        std::unique_ptr<Executable> executable,
-        BuildExecutable(proto, std::move(module_config), execute_backend_.get(),
-                        executor, build_options.device_allocator(),
-                        build_options.run_backend_only()));
+    TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
+                        BuildExecutable(proto, std::move(module_config),
+                                        execute_backend_.get(), executor,
+                                        {build_options.device_allocator(),
+                                         build_options.compile_thread_pool()},
+                                        build_options.run_backend_only()));
     std::vector<std::unique_ptr<Executable>> executables;
     executables.push_back(std::move(executable));
     return executables;
@@ -206,10 +207,12 @@ LocalService::CompileExecutables(
   std::vector<se::StreamExecutor*> executors(build_options.num_partitions(),
                                              executor);

-  return BuildExecutables({&proto}, std::move(module_configs),
-                          execute_backend_.get(), {executors},
-                          build_options.device_allocator(),
-                          build_options.run_backend_only());
+  return BuildExecutables(
+      /*module_protos=*/{&proto}, std::move(module_configs),
+      execute_backend_.get(), {executors},
+      Compiler::CompileOptions{build_options.device_allocator(),
+                               build_options.compile_thread_pool()},
+      build_options.run_backend_only());
 }
 }

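The hunks in this area replace the bare se::DeviceMemoryAllocator* parameter with a small options struct, so the compile thread pool (and future knobs) can be threaded through in one place; call sites can brace-initialize it, as {build_options.device_allocator(), build_options.compile_thread_pool()} does above. A minimal generic sketch of that refactor, using placeholder types rather than the real XLA ones:

// Generic sketch of swapping a single pointer parameter for an options struct.
// All types below are stand-ins, not the XLA definitions.
#include <iostream>

struct DeviceMemoryAllocator {};  // Placeholder for se::DeviceMemoryAllocator.
struct ThreadPool {};             // Placeholder for the compile thread pool.

struct CompileOptions {
  DeviceMemoryAllocator* device_allocator = nullptr;
  ThreadPool* thread_pool = nullptr;
};

// Old style: void Compile(DeviceMemoryAllocator* device_allocator);
// New style: extra fields can be added to CompileOptions without touching
// every override and every caller again.
void Compile(const CompileOptions& options) {
  std::cout << "allocator set: " << (options.device_allocator != nullptr)
            << ", thread pool set: " << (options.thread_pool != nullptr) << "\n";
}

int main() {
  DeviceMemoryAllocator allocator;
  ThreadPool pool;
  Compile({&allocator, &pool});              // Brace-init, as in the diff above.
  Compile({/*device_allocator=*/nullptr});   // Callers that only had an allocator stay terse.
}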
@@ -28,28 +28,24 @@ bool IsUnimplemented(StatusOr<T>& result) {

 StatusOr<std::unique_ptr<HloModule>> FailoverCompiler::RunHloPasses(
     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-    se::DeviceMemoryAllocator* device_allocator) {
-  auto result =
-      primary_->RunHloPasses(module->Clone(), stream_exec, device_allocator);
+    const CompileOptions& options) {
+  auto result = primary_->RunHloPasses(module->Clone(), stream_exec, options);
   if (IsUnimplemented(result)) {
     VLOG(2) << "RunHloPasses resulted in " << result.status()
             << ", falling back to secondary backend";
-    return secondary_->RunHloPasses(std::move(module), stream_exec,
-                                    device_allocator);
+    return secondary_->RunHloPasses(std::move(module), stream_exec, options);
   }
   return result;
 }

 StatusOr<std::unique_ptr<Executable>> FailoverCompiler::RunBackend(
     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-    se::DeviceMemoryAllocator* device_allocator) {
-  auto result =
-      primary_->RunBackend(module->Clone(), stream_exec, device_allocator);
+    const CompileOptions& options) {
+  auto result = primary_->RunBackend(module->Clone(), stream_exec, options);
   if (IsUnimplemented(result)) {
     VLOG(2) << "RunBackend resulted in " << result.status()
             << ", falling back to secondary backend";
-    return secondary_->RunBackend(std::move(module), stream_exec,
-                                  device_allocator);
+    return secondary_->RunBackend(std::move(module), stream_exec, options);
   }
   return result;
 }
@@ -57,7 +53,7 @@ StatusOr<std::unique_ptr<Executable>> FailoverCompiler::RunBackend(
 StatusOr<std::vector<std::unique_ptr<Executable>>> FailoverCompiler::Compile(
     std::unique_ptr<HloModuleGroup> module_group,
     std::vector<std::vector<se::StreamExecutor*>> stream_execs,
-    se::DeviceMemoryAllocator* device_allocator) {
+    const CompileOptions& options) {
   std::vector<std::unique_ptr<Executable>> result;
   std::vector<std::unique_ptr<HloModule>> modules =
       module_group->ConsumeModules();
@@ -67,17 +63,15 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> FailoverCompiler::Compile(
       return Unimplemented(
           "Model partitioning not implemented for the failover compiler!");
     }
-    auto executable = [stream_execs, device_allocator, i,
+    auto executable = [stream_execs, &options, i,
                        this](std::unique_ptr<HloModule> module)
         -> StatusOr<std::unique_ptr<Executable>> {
-      TF_ASSIGN_OR_RETURN(
-          auto processed_module,
-          primary_->RunHloPasses(std::move(module), stream_execs[i][0],
-                                 device_allocator));
-      TF_ASSIGN_OR_RETURN(
-          auto result,
-          primary_->RunBackend(std::move(processed_module), stream_execs[i][0],
-                               device_allocator));
+      TF_ASSIGN_OR_RETURN(auto processed_module,
+                          primary_->RunHloPasses(std::move(module),
+                                                 stream_execs[i][0], options));
+      TF_ASSIGN_OR_RETURN(auto result,
+                          primary_->RunBackend(std::move(processed_module),
+                                               stream_execs[i][0], options));
       return result;
     }(modules[i]->Clone());

@@ -85,13 +79,11 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> FailoverCompiler::Compile(
       VLOG(2) << "Compile resulted in " << executable.status()
               << ", falling back to secondary backend";
       TF_ASSIGN_OR_RETURN(
-          modules[i],
-          secondary_->RunHloPasses(std::move(modules[i]), stream_execs[i][0],
-                                   device_allocator));
-      TF_ASSIGN_OR_RETURN(
-          executable,
-          secondary_->RunBackend(std::move(modules[i]), stream_execs[i][0],
-                                 device_allocator));
+          modules[i], secondary_->RunHloPasses(std::move(modules[i]),
+                                               stream_execs[i][0], options));
+      TF_ASSIGN_OR_RETURN(executable,
+                          secondary_->RunBackend(std::move(modules[i]),
+                                                 stream_execs[i][0], options));
     }

     if (!executable.ok()) {
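The failover logic keeps its shape after the signature change: run the primary backend on a clone, and if it answers Unimplemented, rerun the original module on the secondary backend with the same options. A standalone sketch of that fallback pattern, with std::optional standing in for StatusOr and invented lambdas for the two backends:

// Illustrative primary/secondary fallback, mirroring FailoverCompiler's flow.
#include <functional>
#include <iostream>
#include <optional>
#include <string>

using Result = std::optional<std::string>;  // nullopt plays the role of "unimplemented".

Result RunWithFailover(const std::function<Result(const std::string&)>& primary,
                       const std::function<Result(const std::string&)>& secondary,
                       const std::string& module) {
  // Try the primary backend first (the real code clones the module so the
  // secondary still sees the untouched input).
  if (Result r = primary(module)) return r;
  std::cout << "primary unimplemented, falling back to secondary\n";
  return secondary(module);
}

int main() {
  auto primary = [](const std::string&) -> Result { return std::nullopt; };
  auto secondary = [](const std::string& m) -> Result { return "compiled:" + m; };
  std::cout << *RunWithFailover(primary, secondary, "module0") << "\n";
}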
@@ -51,16 +51,16 @@ class FailoverCompiler final : public Compiler {

   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
       std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) override;
+      const CompileOptions& options) override;

   StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) override;
+      const CompileOptions& options) override;

   StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
       std::unique_ptr<HloModuleGroup> module_group,
       std::vector<std::vector<se::StreamExecutor*>> stream_execs,
-      se::DeviceMemoryAllocator* device_allocator) override;
+      const CompileOptions& options) override;

   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
   CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
@@ -87,16 +87,16 @@ class MlirCompilerImpl : public MlirCompiler {
  public:
   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
       std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) override;
+      const CompileOptions& options) override;

   StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-      se::DeviceMemoryAllocator* device_allocator) override;
+      const CompileOptions& options) override;

   StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
       std::unique_ptr<HloModuleGroup> module_group,
       std::vector<std::vector<se::StreamExecutor*>> stream_execs,
-      se::DeviceMemoryAllocator* device_allocator) override;
+      const CompileOptions& options) override;

   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
   CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
@@ -155,12 +155,12 @@ std::string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) {

 StatusOr<std::unique_ptr<HloModule>> MlirCompilerImpl::RunHloPasses(
     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-    se::DeviceMemoryAllocator* device_allocator) {
+    const CompileOptions& options) {
   // Until we find a reason to do something different, run the same passes
   // that the normal GPU backend runs.
   gpu::NVPTXCompiler xla_compiler;
   TF_RETURN_IF_ERROR(xla_compiler.OptimizeHloModule(module.get(), stream_exec,
-                                                    device_allocator));
+                                                    options.device_allocator));
   TF_RETURN_IF_ERROR(xla_compiler.PrepareHloModuleForIrEmitting(module.get()));

   return std::move(module);
@@ -454,7 +454,7 @@ StatusOr<std::unique_ptr<gpu::KernelThunk>> TransformKernelToXlaThunk(

 StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
-    se::DeviceMemoryAllocator* device_allocator) {
+    const CompileOptions& options) {
   // Determine the HLO schedule, which is an ordering of HLO instructions. This
   // is used by buffer assignment to enable buffer reuse, and the same ordering
   // must also be used to determine the thunk launch schedule.
@@ -595,7 +595,7 @@ StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
 StatusOr<std::vector<std::unique_ptr<Executable>>> MlirCompilerImpl::Compile(
     std::unique_ptr<HloModuleGroup> module_group,
     std::vector<std::vector<se::StreamExecutor*>> stream_execs,
-    se::DeviceMemoryAllocator* device_allocator) {
+    const CompileOptions& options) {
   return Unimplemented("Not yet implemented in MLIR compiler");
 }

@@ -357,7 +357,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
     const std::vector<const HloModuleProto*>& module_protos,
     std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
     Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
-    se::DeviceMemoryAllocator* device_allocator, bool run_backend_only) {
+    const Compiler::CompileOptions& options, bool run_backend_only) {
   VLOG(1) << StrFormat("BuildExecutable on service %p", this);

   // Dump computation proto state if flag is set.
@@ -387,17 +387,15 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(

   std::vector<std::unique_ptr<Executable>> executables;
   if (!run_backend_only) {
-    TF_ASSIGN_OR_RETURN(
-        executables,
-        backend->compiler()->Compile(std::move(module_group),
-                                     std::move(executors), device_allocator));
+    TF_ASSIGN_OR_RETURN(executables, backend->compiler()->Compile(
+                                         std::move(module_group),
+                                         std::move(executors), options));
   } else {
     auto modules = module_group->ConsumeModules();
     for (std::unique_ptr<HloModule>& module : modules) {
-      TF_ASSIGN_OR_RETURN(
-          std::unique_ptr<Executable> executable,
-          backend->compiler()->RunBackend(std::move(module), executors[0][0],
-                                          device_allocator));
+      TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
+                          backend->compiler()->RunBackend(
+                              std::move(module), executors[0][0], options));
       executables.push_back(std::move(executable));
     }
   }
@@ -710,7 +708,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
   TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
                       BuildExecutables(module_protos, std::move(module_configs),
                                        execute_backend_.get(), all_executors,
-                                       /*device_allocator=*/nullptr));
+                                       {/*device_allocator=*/nullptr}));
   std::vector<Executable*> executable_ptrs;
   executable_ptrs.reserve(executables.size());
   for (const auto& executable : executables) {
@@ -810,7 +808,7 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
 StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
     const HloModuleProto& module_proto,
     std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
-    se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator,
+    se::StreamExecutor* executor, const Compiler::CompileOptions& options,
     bool run_backend_only) {
   VLOG(1) << StrFormat(
       "BuildExecutable on service %p with serialized module proto: %s", this,
@@ -822,14 +820,13 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
   DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);

   if (!run_backend_only) {
-    TF_ASSIGN_OR_RETURN(
-        module, backend->compiler()->RunHloPasses(std::move(module), executor,
-                                                  device_allocator));
+    TF_ASSIGN_OR_RETURN(module, backend->compiler()->RunHloPasses(
+                                    std::move(module), executor, options));
   }

-  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
-                      backend->compiler()->RunBackend(
-                          std::move(module), executor, device_allocator));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<Executable> executable,
+      backend->compiler()->RunBackend(std::move(module), executor, options));

   const auto& debug_opts = module_config->debug_options();
   if (DumpingEnabledForHloModule(module_proto.name(), debug_opts) &&
@@ -875,7 +872,7 @@ Status Service::Compile(const CompileRequest* arg, CompileResponse* result) {
                       BuildExecutable(arg->computation(), std::move(module_config),
                                       execute_backend_.get(),
                                       execute_backend_->default_stream_executor(),
-                                      /*device_allocator=*/nullptr));
+                                      {/*device_allocator=*/nullptr}));

   *result->mutable_handle() = compilation_cache_.Insert(std::move(executable));

@@ -235,8 +235,7 @@ class Service : public ServiceInterface {
   StatusOr<std::unique_ptr<Executable>> BuildExecutable(
       const HloModuleProto& module_proto,
       std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
-      se::StreamExecutor* executor,
-      se::DeviceMemoryAllocator* device_allocator = nullptr,
+      se::StreamExecutor* executor, const Compiler::CompileOptions& options,
       bool run_backend_only = false);

   // Same as BuildExecutable() above, but builds a list of Executables for the
@@ -245,8 +244,7 @@ class Service : public ServiceInterface {
       const std::vector<const HloModuleProto*>& module_protos,
       std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
       Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
-      se::DeviceMemoryAllocator* device_allocator,
-      bool run_backend_only = false);
+      const Compiler::CompileOptions& options, bool run_backend_only = false);

   // Runs the given executable with the given arguments and register the result
   // in the allocation tracker. The handle of the result from the tracker is
@@ -102,27 +102,8 @@ bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) {
   }
 }

-// Returns a sharding where each tuple element is chosen as the more specific
-// one of the corresponding elements in a and b. Requires a an b to have the
-// same tuple nesting.
-HloSharding MergeForMoreSpecificSharding(const HloSharding& a,
-                                         const HloSharding& b) {
-  if (a.IsTuple()) {
-    HloSharding result = a;
-    CHECK(b.IsTuple());
-    CHECK_EQ(a.tuple_elements().size(), b.tuple_elements().size());
-    for (int64 i = 0; i < result.tuple_elements().size(); ++i) {
-      result.tuple_elements()[i] = MergeForMoreSpecificSharding(
-          a.tuple_elements()[i], b.tuple_elements()[i]);
-    }
-    return result;
-  }
-  return IsShardingMoreSpecific(a, b) ? a : b;
-}
-
 // Tries to refine `to_merge` by combining with `old`. Returns if the final
-// `to_merge` is more specific than `old`. May combine partial sharding in
-// addition to MergeForMoreSpecificSharding().
+// `to_merge` is more specific than `old`.
 bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
                    bool may_combine_partial_sharding) {
   if (old.IsTuple()) {
@@ -1093,8 +1074,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
       }
       auto sharding = instruction->operand(0)->sharding();
       if (instruction->has_sharding()) {
-        sharding =
-            MergeForMoreSpecificSharding(sharding, instruction->sharding());
+        MergeSharding(instruction->sharding(), &sharding,
+                      may_combine_partial_sharding);
       }
       return MaybeImproveInstructionSharding(std::move(sharding), instruction,
                                              may_combine_partial_sharding);
@@ -1320,6 +1301,12 @@ absl::optional<HloSharding> GetShardingFromUser(
       return hlo_sharding_util::ReshapeSharding(
           user.shape(), instruction.shape(), user.sharding());
     }
+    case HloOpcode::kPad: {
+      if (&instruction != user.operand(0)) {
+        return absl::nullopt;
+      }
+      return user.sharding();
+    }
     case HloOpcode::kSlice: {
       return user.sharding();
     }
@@ -1673,8 +1660,10 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
   // If instruction is a while, or the root or a parameter of a while body,
   // then propagate its sharding to the while instruction, to its body root,
   // and to its condition parameter.
-  std::function<void(HloInstruction*)> maybe_computation_propagation =
-      [&](HloInstruction* instruction) {
+  std::function<void(HloInstruction*, absl::flat_hash_set<HloInstruction*>*)>
+      maybe_computation_propagation = [&](HloInstruction* instruction,
+                                          absl::flat_hash_set<HloInstruction*>*
+                                              changed) {
         auto propagate_to_instruction = [&](HloInstruction* search_inst) {
           auto related_instructions = get_related_instructions(search_inst);
           if (absl::c_count(related_instructions, instruction)) {
@@ -1683,7 +1672,8 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
                 inst->sharding() != instruction->sharding()) {
               VLOG(2) << "Add computation sharding: " << inst->name();
               inst->set_sharding(instruction->sharding());
-              maybe_computation_propagation(inst);
+              changed->insert(inst);
+              maybe_computation_propagation(inst, changed);
             }
           }
         }
@@ -1785,6 +1775,14 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
     for (const HloInstruction* instruction : instructions) {
       already_sharded_counter += (instruction->has_sharding() ? 1 : 0);
     }
+    auto clear_cache = [&](HloInstruction* hlo) {
+      for (auto operand : hlo->operands()) {
+        already_inferred_from_users.erase(operand);
+      }
+      for (auto user : hlo->users()) {
+        already_inferred_from_operands.erase(user);
+      }
+    };
     // First iterate the HLO graph in post order taking shardings from
     // operands.
     for (HloInstruction* instruction : instructions) {
@@ -1799,12 +1797,11 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
           any_changed = true;
           VLOG(2) << "Add sharding (forward-pass): "
                   << instruction->ToString();
-          maybe_computation_propagation(instruction);
-          for (auto operand : instruction->operands()) {
-            already_inferred_from_users.erase(operand);
-          }
-          for (auto user : instruction->users()) {
-            already_inferred_from_operands.erase(user);
+          absl::flat_hash_set<HloInstruction*> changed_in_comp_prop;
+          maybe_computation_propagation(instruction, &changed_in_comp_prop);
+          clear_cache(instruction);
+          for (auto hlo : changed_in_comp_prop) {
+            clear_cache(hlo);
           }
           changed_last_iter = true;
         }
@@ -1823,12 +1820,11 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
           ++inferred_from_user_counter;
           any_changed = true;
           VLOG(2) << "Add sharding (backward-pass): " << (*it)->ToString();
-          maybe_computation_propagation(*it);
-          for (auto operand : (*it)->operands()) {
-            already_inferred_from_users.erase(operand);
-          }
-          for (auto user : (*it)->users()) {
-            already_inferred_from_operands.erase(user);
+          absl::flat_hash_set<HloInstruction*> changed_in_comp_prop;
+          maybe_computation_propagation(*it, &changed_in_comp_prop);
+          clear_cache(*it);
+          for (auto hlo : changed_in_comp_prop) {
+            clear_cache(hlo);
           }
           changed_last_iter = true;
         }
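The clear_cache lambda and the changed_in_comp_prop set above exist so the fixed-point loop re-examines the neighbourhood of every instruction whose sharding just changed, including instructions reached indirectly through computation propagation. A toy version of that bookkeeping on a three-node chain (hypothetical Node type; the real pass works on HloInstruction*):

// Toy fixed-point propagation with the "clear the neighbours' cache" idea.
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

struct Node {
  std::string name;
  std::vector<Node*> operands;
  std::vector<Node*> users;
};

int main() {
  Node a{"a"}, b{"b"}, c{"c"};
  a.users = {&b};
  b.operands = {&a};
  b.users = {&c};
  c.operands = {&b};

  std::unordered_set<Node*> already_inferred_from_operands;
  std::unordered_set<Node*> already_inferred_from_users;

  // Pretend both caches are warm for every node.
  for (Node* n : {&a, &b, &c}) {
    already_inferred_from_operands.insert(n);
    already_inferred_from_users.insert(n);
  }

  // When `hlo` changes, drop its operands from the backward cache and its
  // users from the forward cache so the next iteration revisits them.
  auto clear_cache = [&](Node* hlo) {
    for (Node* operand : hlo->operands) already_inferred_from_users.erase(operand);
    for (Node* user : hlo->users) already_inferred_from_operands.erase(user);
  };

  clear_cache(&b);  // b's sharding changed.
  std::cout << "a needs a backward revisit: "
            << (already_inferred_from_users.count(&a) == 0) << "\n";
  std::cout << "c needs a forward revisit: "
            << (already_inferred_from_operands.count(&c) == 0) << "\n";
}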
@@ -514,6 +514,26 @@ ENTRY %pad {
               op::Sharding("{devices=[2,2]0,1,2,3}"));
 }

+TEST_F(ShardingPropagationTest, PadBackwardPass) {
+  const char* const hlo_string = R"(
+HloModule module
+ENTRY %pad {
+  %input = f32[11,17]{1,0} parameter(0)
+  %copy = f32[11,17]{1,0} copy(%input)
+  %pad_value = f32[] parameter(1)
+  %pad = f32[27,51]{1,0} pad(%copy, %pad_value), padding=2_4_1x1_1_2,
+    sharding={devices=[2,2]0,1,2,3}
+  ROOT %result = f32[27,51]{1,0} copy(%pad)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          ShardingPropagation().Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(FindInstruction(module.get(), "copy"),
+              op::Sharding("{devices=[2,2]0,1,2,3}"));
+}
+
 TEST_F(ShardingPropagationTest, PartialReplicatedPadForwardPass) {
   const char* const hlo_string = R"(
 HloModule module
@@ -856,40 +876,41 @@ TEST_F(ShardingPropagationTest, While) {
 HloModule module

 %cond {
-  %vars.cond = (u32[], f32[10]{0}) parameter(0)
-  %count.cond = u32[] get-tuple-element((u32[], f32[10]{0}) %vars.cond), index=0
+  %vars.cond = (u32[], f32[10,10]) parameter(0)
+  %count.cond = u32[] get-tuple-element((u32[], f32[10,10]) %vars.cond), index=0
   %limit = u32[] constant(10)
   ROOT %lt = pred[] compare(u32[] %count.cond, u32[] %limit), direction=LT
 }

 %body {
-  %vars = (u32[], f32[10]{0}) parameter(0)
+  %vars = (u32[], f32[10,10]) parameter(0)
   %count = u32[] get-tuple-element(%vars), index=0
-  %acc = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %vars), index=1
+  %acc = f32[10,10] get-tuple-element((u32[], f32[10,10]) %vars), index=1

   %one = u32[] constant(1)
   %count.1 = u32[] add(u32[] %count, u32[] %one), sharding={replicated}
-  %acc.1 = f32[10]{0} add(f32[10]{0} %acc, f32[10]{0} %acc)
-  ROOT %tuple = (u32[], f32[10]{0}) tuple(u32[] %count.1, f32[10]{0} %acc.1)
+  %acc.1 = f32[10,10] add(f32[10,10] %acc, f32[10,10] %acc)
+  ROOT %tuple = (u32[], f32[10,10]) tuple(u32[] %count.1, f32[10,10] %acc.1)
 }

 ENTRY %entry {
-  %p0 = f32[10]{0} parameter(0)
-  %p0.copy = f32[10]{0} copy(f32[10]{0} %p0)
-  %p1 = f32[10]{0} parameter(1)
+  %p0 = f32[10,10] parameter(0)
+  %p0.copy = f32[10,10] copy(f32[10,10] %p0)
+  %p1 = f32[10,10] parameter(1)
   %zero = u32[] constant(0)
-  %init = (u32[], f32[10]{0}) tuple(u32[] %zero, f32[10]{0} %p0.copy)
-  %while = (u32[], f32[10]{0}) while((u32[], f32[10]{0}) %init),
+  %init = (u32[], f32[10,10]) tuple(u32[] %zero, f32[10,10] %p0.copy)
+  %while = (u32[], f32[10,10]) while((u32[], f32[10,10]) %init),
     body=%body, condition=%cond
-  %res = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %while), index=1
-  %prev = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %init), index=1
-  %res.1 = f32[10]{0} multiply(f32[10]{0} %res, %prev)
-  ROOT %res_tuple = (f32[10]{0}) tuple(f32[10]{0} %res.1)
+  %res = f32[10,10] get-tuple-element((u32[], f32[10,10]) %while), index=1
+  %prev = f32[10,10] get-tuple-element((u32[], f32[10,10]) %init), index=1
+  %res.1 = f32[10,10] multiply(f32[10,10] %res, %prev)
+  ROOT %res_tuple = (f32[10,10]) tuple(f32[10,10] %res.1)
 })";

   auto while_is_sharded = [this](HloModule* module,
                                  const HloSharding& sharding) {
-    TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardingPropagation().Run(module));
+    TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                            ShardingPropagation(/*is_spmd=*/true).Run(module));
     EXPECT_TRUE(changed);
     auto while_instr = FindInstruction(module, "while");
     EXPECT_NE(nullptr, while_instr);
@@ -911,7 +932,7 @@ ENTRY %entry {
     auto body_root = FindInstruction(module.get(), "tuple");
     EXPECT_NE(nullptr, body_root);
     auto sharding =
-        ParseSharding("{{replicated}, {devices=[2]0,1}}").ConsumeValueOrDie();
+        ParseSharding("{{replicated}, {devices=[2,1]0,1}}").ConsumeValueOrDie();
     body_root->set_sharding(sharding);
     while_is_sharded(module.get(), sharding);
   }
@@ -921,11 +942,30 @@ ENTRY %entry {
                             ParseAndReturnVerifiedModule(hlo_string));
     auto acc_1 = FindInstruction(module.get(), "acc.1");
     EXPECT_NE(nullptr, acc_1);
-    acc_1->set_sharding(ParseSharding("{devices=[2]0,1}").ConsumeValueOrDie());
+    acc_1->set_sharding(
+        ParseSharding("{devices=[2,1]0,1}").ConsumeValueOrDie());

-    while_is_sharded(
-        module.get(),
-        ParseSharding("{{replicated}, {devices=[2]0,1}}").ConsumeValueOrDie());
+    while_is_sharded(module.get(),
+                     ParseSharding("{{replicated}, {devices=[2,1]0,1}}")
+                         .ConsumeValueOrDie());
+  }
+  {
+    // Merge partial sharding from operand and body.
+    TF_ASSERT_OK_AND_ASSIGN(auto module,
+                            ParseAndReturnVerifiedModule(hlo_string));
+    auto acc_1 = FindInstruction(module.get(), "acc.1");
+    EXPECT_NE(nullptr, acc_1);
+    acc_1->set_sharding(
+        ParseSharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}")
+            .ConsumeValueOrDie());
+    auto p0 = FindInstruction(module.get(), "p0");
+    p0->set_sharding(
+        ParseSharding("{devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}")
+            .ConsumeValueOrDie());
+
+    while_is_sharded(module.get(),
+                     ParseSharding("{{replicated}, {devices=[2,2]0,1,2,3}}")
+                         .ConsumeValueOrDie());
   }
 }

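For readers unfamiliar with the sharding strings used in these tests: {devices=[2,2]0,1,2,3} describes a 2x2 tile grid with device ids listed in the grid's iteration order, and last_tile_dim_replicate marks the last grid dimension as replication rather than tiling. The small program below only illustrates how such a grid maps tile coordinates to devices; it is not XLA's parser, and the shapes are taken from the test above.

// Decode a [rows, cols] tile assignment like "devices=[2,2]0,1,2,3".
#include <iostream>
#include <vector>

int main() {
  const int rows = 2, cols = 2;
  const std::vector<int> devices = {0, 1, 2, 3};  // Device order over the grid.
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      // Tile (r, c) of the f32[10,10] operand lives on this device.
      std::cout << "tile(" << r << "," << c << ") -> device "
                << devices[r * cols + c] << "\n";
    }
  }
}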
@@ -182,6 +182,10 @@ class ConvolutionVisitor {
     return permute_dims[id];
   }

+  int64 ReverseDimLookUp(absl::Span<const int64> permute_dims, int64 id) {
+    return std::distance(permute_dims.begin(), absl::c_find(permute_dims, id));
+  }
+
   HloInstruction* DoesConvolutionFeedReduceWindowOrSelectAndScatter(
       HloInstruction* instr, int64 depth);

@@ -215,9 +219,10 @@ class ConvolutionVisitor {
   // Limit on batch size to apply this technique on.
   int64 limit_on_batch_size_;

-  // We choose the new batch size to be a constant so that space-to-batch
-  // propagation through several convolutional layers is consistent.
-  static constexpr int64 kNewBatchSize = 8;
+  // We choose the new batch size to be kNumSplits times that of the old batch
+  // so that space-to-batch propagation through several convolutional layers is
+  // consistent.
+  static constexpr int64 kNumSplits = 8;

   // Depth for searching reduce window
   static constexpr int64 kReduceWindowSearchDepth = 10;
@@ -301,17 +306,12 @@ bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch(
   if (old_batch_size > limit_on_batch_size_) {
     return false;
   }
-  // We currently only cater to evenly divisible cases.
-  if (kNewBatchSize % old_batch_size != 0) {
-    return false;
-  }

   VLOG(1) << "spatial size " << c.spatial_size;

-  const int64 num_splits = kNewBatchSize / old_batch_size;
   // If the ratio is not within the 2X range, we can't Halo Pad from the next
   // split.
-  if (c.halo_size > CeilOfRatio(c.spatial_size, num_splits)) {
+  if (c.halo_size > CeilOfRatio(c.spatial_size, kNumSplits)) {
     return false;
   }
   VLOG(1) << "Legal space-to-batch convolution " << convolution->ToString();
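With kNumSplits fixed, IsConvSuitableForSpaceToBatch above only rejects a convolution when the halo it needs is larger than one split of the chosen spatial dimension, i.e. halo_size > CeilOfRatio(spatial_size, kNumSplits). The same arithmetic in isolation, with made-up sizes:

// The 2X-range guard from the hunk above, as plain arithmetic.
#include <cstdint>
#include <iostream>

int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t kNumSplits = 8;
  const int64_t spatial_size = 56;  // Example spatial extent (assumed value).
  const int64_t halo_size = 6;      // Overlap needed from the next split (assumed value).
  const int64_t split_size = CeilOfRatio(spatial_size, kNumSplits);  // 7
  const bool suitable = halo_size <= split_size;
  std::cout << "split_size=" << split_size
            << " suitable=" << std::boolalpha << suitable << "\n";  // true
}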
@@ -323,6 +323,24 @@ StatusOr<HloInstruction*> ConvolutionVisitor::HaloDuplicateWithSlice(
     int64 activations_batch_dim, int64 old_batch_size, int64 low_padding,
     int64 high_padding, int64 halo_size, int64 original_split_dim_size,
     HloInstruction* pad_val) {
+  const int64 original_batch_size =
+      activations->shape().dimensions(activations_batch_dim) / kNumSplits;
+
+  if (original_batch_size > 1) {
+    std::vector<int64> new_dimensions(activations->shape().dimensions().begin(),
+                                      activations->shape().dimensions().end());
+    new_dimensions[activations_batch_dim] = kNumSplits;
+    new_dimensions.insert(new_dimensions.begin() + activations_batch_dim,
+                          original_batch_size);
+
+    // Reshape the output of the new conv into the old convolutions shape.
+    TF_ASSIGN_OR_RETURN(activations,
+                        MakeReshapeHlo(new_dimensions, activations));
+
+    spatial_dimension_to_split++;
+    activations_batch_dim++;
+  }
+
   const int64 rank = activations->shape().rank();
   const int64 spatial_split_size =
       activations->shape().dimensions(spatial_dimension_to_split);
@@ -415,6 +433,21 @@ StatusOr<HloInstruction*> ConvolutionVisitor::HaloDuplicateWithSlice(
     TF_ASSIGN_OR_RETURN(activations, MakeConcatHlo({activations, halo_region},
                                                     spatial_dimension_to_split));
   }
+
+  if (original_batch_size > 1) {
+    std::vector<int64> new_dimensions(activations->shape().dimensions().begin(),
+                                      activations->shape().dimensions().end());
+    new_dimensions[activations_batch_dim] = original_batch_size * kNumSplits;
+    new_dimensions.erase(new_dimensions.begin() + activations_batch_dim - 1);
+
+    // Reshape the output of the new conv into the old convolutions shape.
+    TF_ASSIGN_OR_RETURN(activations,
+                        MakeReshapeHlo(new_dimensions, activations));
+
+    spatial_dimension_to_split++;
+    activations_batch_dim++;
+  }
+
   VLOG(1) << "HaloDuplicated activations " << activations->ToString();
   return activations;
 }
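The new blocks in HaloDuplicateWithSlice handle an operand whose batch dimension already holds original_batch_size * kNumSplits: the batch is temporarily reshaped into [original_batch_size, kNumSplits] and merged back afterwards. The dimension bookkeeping alone, with invented shapes:

// Shape-only illustration of splitting and re-merging the batch dimension.
#include <cstdint>
#include <iostream>
#include <vector>

void Print(const char* tag, const std::vector<int64_t>& dims) {
  std::cout << tag << ": [";
  for (size_t i = 0; i < dims.size(); ++i)
    std::cout << dims[i] << (i + 1 < dims.size() ? "," : "");
  std::cout << "]\n";
}

int main() {
  const int64_t kNumSplits = 8;
  const int64_t original_batch_size = 2;
  // Activations after space-to-batch: the batch dim holds 2 * 8 = 16.
  std::vector<int64_t> dims = {16, 7, 7, 64};  // Assumed NHWC-like shape.
  int64_t batch_dim = 0;

  // Split: [16,7,7,64] -> [2,8,7,7,64] (insert the original batch in front).
  dims[batch_dim] = kNumSplits;
  dims.insert(dims.begin() + batch_dim, original_batch_size);
  Print("split", dims);

  // Merge back: [2,8,7,7,64] -> [16,7,7,64].
  int64_t merged_batch_dim = batch_dim + 1;
  dims[merged_batch_dim] = original_batch_size * kNumSplits;
  dims.erase(dims.begin() + merged_batch_dim - 1);
  Print("merged", dims);
}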
@@ -424,17 +457,20 @@ ConvolutionVisitor::BringSpaceNextToBatch(
     HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
     int64& spatial_dimension_to_split, int64& activations_batch_dim,
     bool is_backprop) {
-  std::vector<int64> transpose_dims;
-  ConvolutionDimensionNumbers new_dim_numbers = dim_numbers;
-  if (spatial_dimension_to_split != activations_batch_dim + 1) {
+  std::vector<int64> transpose_dims(activations->shape().rank());
+  if (spatial_dimension_to_split == activations_batch_dim + 1) {
+    absl::c_iota(transpose_dims, 0);
+  } else {
+    ConvolutionDimensionNumbers new_dim_numbers = dim_numbers;
     int64 pushed_counter = 0;
     int64 new_batch_dim, new_spatial_dim;
+    int64 dim_counter = 0;
     for (int i = 0; i < activations->shape().rank(); ++i) {
       if (i == activations_batch_dim) {
         continue;
       }
       if (i == spatial_dimension_to_split) {
-        transpose_dims.push_back(activations_batch_dim);
+        transpose_dims[dim_counter++] = activations_batch_dim;
         new_batch_dim = pushed_counter;
         pushed_counter++;
         new_spatial_dim = pushed_counter;
@@ -452,7 +488,7 @@ ConvolutionVisitor::BringSpaceNextToBatch(
         }
       }
     }
-    transpose_dims.push_back(i);
+    transpose_dims[dim_counter++] = i;
     pushed_counter++;
   }

@@ -460,14 +496,14 @@ ConvolutionVisitor::BringSpaceNextToBatch(
     spatial_dimension_to_split = new_spatial_dim;
     TF_ASSIGN_OR_RETURN(activations,
                         MakeTransposeHlo(activations, transpose_dims));
-  }

     if (is_backprop) {
       new_dim_numbers.set_input_feature_dimension(activations_batch_dim);
     } else {
       new_dim_numbers.set_input_batch_dimension(activations_batch_dim);
+    }
+    dim_numbers = new_dim_numbers;
   }
-  dim_numbers = new_dim_numbers;

   return SpaceNextToBatchDetails{activations, transpose_dims};
 }
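BringSpaceNextToBatch now fills transpose_dims by index, and the class gains ReverseDimLookUp as the inverse of DimLookUp: one returns permute_dims[id], the other the position at which id occurs. A standalone copy of the two lookups over a small permutation (std::find replaces absl::c_find so the sketch compiles on its own):

// Forward and inverse lookup into a permutation vector, as in the helpers above.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int64_t DimLookUp(const std::vector<int64_t>& permute_dims, int64_t id) {
  return permute_dims[id];
}

int64_t ReverseDimLookUp(const std::vector<int64_t>& permute_dims, int64_t id) {
  return std::distance(permute_dims.begin(),
                       std::find(permute_dims.begin(), permute_dims.end(), id));
}

int main() {
  const std::vector<int64_t> permute_dims = {0, 3, 1, 2};
  std::cout << DimLookUp(permute_dims, 1) << "\n";         // 3
  std::cout << ReverseDimLookUp(permute_dims, 3) << "\n";  // 1: the inverse of the line above.
}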
@ -586,12 +622,23 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
|
|||||||
VLOG(1) << "Checking if conv is supported for propagation "
|
VLOG(1) << "Checking if conv is supported for propagation "
|
||||||
<< consumer->ToString();
|
<< consumer->ToString();
|
||||||
if (IsConvSuitableForSpaceToBatch(consumer)) {
|
if (IsConvSuitableForSpaceToBatch(consumer)) {
|
||||||
for (int64 i = 0; i < consumer->operand_count(); ++i) {
|
if (!old_to_new_instrs_.contains(consumer->mutable_operand(0))) {
|
||||||
auto old_producer = consumer->mutable_operand(i);
|
return false;
|
||||||
if (i == 0 && !old_to_new_instrs_.contains(old_producer)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)];
|
||||||
|
// Make sure that the space dimension is the same across the producer
|
||||||
|
// and consumer.
|
||||||
|
if (consumer->convolution_dimension_numbers().input_spatial_dimensions(
|
||||||
|
get_chosen_spatial_dim(consumer)) != dim_map_val_op_0.second) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// Make sure that the batch dimension is the same across the producer
|
||||||
|
// and consumer.
|
||||||
|
if (consumer->convolution_dimension_numbers().input_batch_dimension() !=
|
||||||
|
dim_map_val_op_0.first) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -611,13 +658,35 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
|
|||||||
VLOG(2) << "Checking for backprop filter conv operands "
|
VLOG(2) << "Checking for backprop filter conv operands "
|
||||||
<< consumer->operand_count();
|
<< consumer->operand_count();
|
||||||
|
|
||||||
if (!old_to_new_instrs_.contains(consumer->mutable_operand(1))) {
|
auto activations = consumer->mutable_operand(0);
|
||||||
|
auto kernel = consumer->mutable_operand(1);
|
||||||
|
|
||||||
|
if (!old_to_new_instrs_.contains(kernel)) {
|
||||||
VLOG(2) << "Backprop filter conv not ready for propagation because of "
|
VLOG(2) << "Backprop filter conv not ready for propagation because of "
|
||||||
"kernel is not space-to-batched";
|
"kernel is not space-to-batched";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!old_to_new_instrs_.contains(consumer->mutable_operand(0))) {
|
if (!old_to_new_instrs_.contains(activations)) {
|
||||||
|
const int64 lhs_batch = activations->shape().dimensions(
|
||||||
|
consumer->convolution_dimension_numbers().input_feature_dimension());
|
||||||
|
auto dim_map_val_op_1 = instr_to_dim_map_[consumer->mutable_operand(1)];
|
||||||
|
const int64 old_batch_dim = dim_map_val_op_1.first;
|
||||||
|
auto second_operand = old_to_new_instrs_[kernel];
|
||||||
|
auto permute_dims_second_operand =
|
||||||
|
instr_to_dim_permute_map_[second_operand];
|
||||||
|
const int64 new_batch_dim =
|
||||||
|
DimLookUp(permute_dims_second_operand, old_batch_dim);
|
||||||
|
const int64 rhs_batch = second_operand->shape().dimensions(new_batch_dim);
|
||||||
|
|
||||||
|
// Because we want to convert activations into a space-to-batched version
|
||||||
|
// only for backprop filter convolutions, we want to make sure that the
|
||||||
|
// batch dimensions (feature dimensions, technically) are same sized.
|
||||||
|
// Since RHS is already space-to-batched, we need to account for it too.
|
||||||
|
if (rhs_batch != kNumSplits * lhs_batch) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// If activations have not been propagated through, we can do
|
// If activations have not been propagated through, we can do
|
||||||
// space-to-batch on them provided kernel has been propagated.
|
// space-to-batch on them provided kernel has been propagated.
|
||||||
VLOG(2) << "Backprop filter conv ready for propagation: kernel ready, "
|
VLOG(2) << "Backprop filter conv ready for propagation: kernel ready, "
|
||||||
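A minimal sketch of the size check added in the hunk above: the kernel (RHS) is already space-to-batched, so its batch-like dimension must equal the activations' feature dimension times the split factor. The constant value and names below are assumptions for illustration, not taken from the diff.

// Illustrative sketch, not part of the diff: the arithmetic behind the
// `rhs_batch != kNumSplits * lhs_batch` guard. kNumSplits' value and the
// helper name are assumptions for this example.
#include <cstdint>
#include <iostream>

constexpr int64_t kNumSplits = 8;  // assumed split factor

// True when a space-to-batched kernel (rhs) is size-consistent with
// not-yet-propagated activations (lhs): its batch-like dimension must have
// grown by exactly the split factor.
bool BatchDimsConsistent(int64_t lhs_batch, int64_t rhs_batch) {
  return rhs_batch == kNumSplits * lhs_batch;
}

int main() {
  std::cout << BatchDimsConsistent(4, 32) << " "    // 1: 4 * 8 == 32
            << BatchDimsConsistent(4, 24) << "\n";  // 0: mismatch, no propagation
}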
@@ -625,10 +694,10 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
     return true;
   }
 
-  auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];
-  auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)];
-  auto second_operand = old_to_new_instrs_[consumer->mutable_operand(1)];
-  auto dim_map_val_op_1 = instr_to_dim_map_[consumer->mutable_operand(1)];
+  auto first_operand = old_to_new_instrs_[activations];
+  auto dim_map_val_op_0 = instr_to_dim_map_[activations];
+  auto second_operand = old_to_new_instrs_[kernel];
+  auto dim_map_val_op_1 = instr_to_dim_map_[kernel];
 
   auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand];
   auto permute_dims_second_operand =
@@ -1119,7 +1188,7 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
 
   Window new_win;
   for (int64 i = 0; i < consumer->window().dimensions().size(); ++i) {
-    auto dim = DimLookUp(permute_dims, i);
+    auto dim = ReverseDimLookUp(permute_dims, i);
     new_win.add_dimensions();
     new_win.mutable_dimensions(i)->set_stride(
         consumer->window().dimensions(dim).stride());
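The hunk above switches the window rebuild from DimLookUp to ReverseDimLookUp. The helpers themselves are not shown in this diff; the sketch below is one plausible reading (forward lookup returns where a dimension moved to, reverse lookup returns which dimension landed at a position), included only to make the direction of the permutation concrete.

// Sketch only: assumed semantics of the two lookups. The real helpers live
// elsewhere in the pass.
#include <cstdint>
#include <iostream>
#include <vector>

// Forward lookup: where did old dimension `dim` move to?
int64_t DimLookUp(const std::vector<int64_t>& permute_dims, int64_t dim) {
  return permute_dims[dim];
}

// Reverse lookup: which old dimension ended up at position `dim`?
int64_t ReverseDimLookUp(const std::vector<int64_t>& permute_dims, int64_t dim) {
  for (int64_t i = 0; i < static_cast<int64_t>(permute_dims.size()); ++i) {
    if (permute_dims[i] == dim) return i;
  }
  return -1;  // not found; the real code would CHECK-fail instead
}

int main() {
  const std::vector<int64_t> permute = {2, 0, 1};
  // Forward: dimension 0 maps to 2. Reverse: position 0 was dimension 1.
  std::cout << DimLookUp(permute, 0) << " " << ReverseDimLookUp(permute, 0)
            << "\n";  // prints "2 1"
}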
@@ -1339,7 +1408,9 @@ StatusOr<HloInstruction*> ConvolutionVisitor::SelectValidPortion(
   const int64 new_space_size = new_shape.dimensions(new_space_dim);
   const int64 old_batch_size = old_shape.dimensions(old_batch_dim);
   const int64 old_space_size = old_shape.dimensions(old_space_dim);
-  CHECK_EQ(new_batch_size % old_batch_size, 0);
+  CHECK_EQ(new_batch_size % old_batch_size, 0)
+      << " New batch size " << new_batch_size << " old batch size "
+      << old_batch_size;
   const int64 num_splits = new_batch_size / old_batch_size;
   // Build a constant PRED to decide which elements in the split dimension
   // are from halo.
@@ -1394,8 +1465,10 @@ StatusOr<HloInstruction*> ConvolutionVisitor::BatchToSpace(
   CHECK(old_to_new_instrs_.contains(old_instr));
   auto new_instr = old_to_new_instrs_[old_instr];
   VLOG(2) << "old_batch_dim " << old_batch_dim << " old_space_dim "
-          << old_space_dim << " new_instr " << new_instr->ToString()
-          << " permute dims " << instr_to_dim_permute_map_.count(new_instr);
+          << old_space_dim << " old_instr " << old_instr->ToString()
+          << "\n new_instr " << new_instr->ToString() << " permute dims "
+          << instr_to_dim_permute_map_.count(new_instr) << " old_batch_size "
+          << old_batch_size;
   CHECK(instr_to_dim_permute_map_.contains(new_instr));
   auto permute_dims = instr_to_dim_permute_map_[new_instr];
   const int64 batch_dim = DimLookUp(permute_dims, old_batch_dim);
@@ -1565,6 +1638,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
           c.spatial_dimension_to_split, activations_batch_dim));
   activations_new = retval.instr;
   std::vector<int64> trans_dims = retval.transpose_dims;
+  CHECK(!trans_dims.empty());
   auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
       LiteralUtil::Zero(activations_new->shape().element_type())));
 
@@ -1578,8 +1652,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
 
   VLOG(1) << "spatial size " << c.spatial_size;
 
-  const int64 num_splits = kNewBatchSize / old_batch_size;
-
+  const int64 num_splits = kNumSplits;
   const int64 output_offsets = convolution->shape().dimensions(
       permuted_conv_dims_numbers.output_spatial_dimensions(
           get_chosen_spatial_dim(convolution)));
@@ -1614,6 +1687,8 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
                                     activations_new->shape().dimensions().end());
   const int64 reshaped_space_size =
       new_space_size * new_batch_size / old_batch_size;
+  VLOG(3) << "Increasing the spatial size while propagating new_batch_size "
+          << new_batch_size << " old_batch_size " << old_batch_size;
   new_dimensions[c.spatial_dimension_to_split] = reshaped_space_size;
   new_dimensions[activations_batch_dim] = old_batch_size;
 
@@ -1621,10 +1696,12 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
   TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_activations,
                       MakeReshapeHlo(new_dimensions, activations_new));
 
+  VLOG(3) << "First reshape done";
   PaddingConfig padding_config =
       MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size());
   padding_config.mutable_dimensions(c.spatial_dimension_to_split)
-      ->set_edge_padding_high(spatial_split_size * new_batch_size -
+      ->set_edge_padding_high(spatial_split_size * new_batch_size /
+                                  old_batch_size -
                               reshaped_space_size);
   padding_config.mutable_dimensions(c.spatial_dimension_to_split)
       ->set_edge_padding_low(0);
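The padding fix in the hunk above divides by old_batch_size because reshaped_space_size is also expressed per original batch element. A small worked example with made-up sizes:

// Illustrative numbers only: why edge_padding_high is now divided by
// old_batch_size. reshaped_space_size already counts
// new_batch_size / old_batch_size splits per original batch element, so the
// padding target has to be expressed in the same units.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t old_batch_size = 2;    // assumed example values
  const int64_t new_batch_size = 16;   // old batch * 8 splits
  const int64_t new_space_size = 5;
  const int64_t spatial_split_size = 6;

  const int64_t reshaped_space_size =
      new_space_size * new_batch_size / old_batch_size;  // 40
  const int64_t pad_high =
      spatial_split_size * new_batch_size / old_batch_size -
      reshaped_space_size;  // 48 - 40 = 8
  // The previous formula, 6 * 16 - 40 = 56, over-pads because it is not
  // measured per original batch element.
  std::cout << "reshaped_space_size=" << reshaped_space_size
            << " pad_high=" << pad_high << "\n";
}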
@@ -1647,6 +1724,8 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
       reshaped_activations,
       MakeReshapeHlo(reshape_back_dims, reshaped_activations));
 
+  VLOG(3) << "Second reshape done";
+
   TF_ASSIGN_OR_RETURN(
       activations_new,
       HaloDuplicateWithSlice(
@@ -1664,6 +1743,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
   // additional space available, and adjust the required slice size (and
   // thereby the halo size).
   if (spatial_split_size < new_space_size) {
+    VLOG(3) << "Decreasing the spatial size while propagating";
     const int64 additional_space_present = spatial_split_size % c.stride;
     spatial_split_size = new_space_size;
     slice_size =
@@ -1758,6 +1838,7 @@ ConvolutionVisitor::SplitSpace(HloInstruction* activations,
 
   activations = retval.instr;
   std::vector<int64> transpose_dims = retval.transpose_dims;
+  CHECK(!transpose_dims.empty());
   // Because we are splitting the spatial dimension, if convolution needed
   // padding in the spatial dimension, we materialize it.
   if (high_padding || low_padding) {
@@ -1774,7 +1855,9 @@ ConvolutionVisitor::SplitSpace(HloInstruction* activations,
                         MakePadHlo(activations, padding, padding_config));
   }
   VLOG(1) << "Initial padded activations shape "
-          << activations->shape().ToString();
+          << activations->shape().ToString() << " old_batch_size "
+          << old_batch_size << " activations_batch_dim "
+          << activations_batch_dim;
 
   // Now we reorganize the activations. E.g. if the shape [B, SPACE] was [1, 16]
   // and 4 splits were needed, we first create [4, 4]. Next, to deal with halo
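The comment in the hunk above describes the reorganization informally. Below is a minimal sketch of the same idea on plain arrays, with made-up sizes: a [B=1, SPACE=16] buffer split 4 ways becomes [4, 4]; the halo handling that HaloDuplicateWithSlice performs afterwards is only hinted at in the comments.

// Sketch only: the [1, 16] -> [4, 4] reorganization from the comment above,
// done on plain arrays with made-up values. Halo duplication is left out.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t space = 16, num_splits = 4;
  const int64_t split_size = space / num_splits;  // 4

  std::vector<int> flat(space);
  for (int i = 0; i < space; ++i) flat[i] = i;

  // "Reshape" [1, 16] -> [4, 4]: split s holds elements [s*4, s*4 + 4).
  std::vector<std::vector<int>> batched(num_splits,
                                        std::vector<int>(split_size));
  for (int64_t s = 0; s < num_splits; ++s) {
    for (int64_t j = 0; j < split_size; ++j) {
      batched[s][j] = flat[s * split_size + j];
    }
  }
  // Each split would additionally need a halo (the first few elements of the
  // next split) before the convolution can run on it.
  for (const auto& row : batched) {
    for (int v : row) std::cout << v << ' ';
    std::cout << '\n';
  }
}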
@@ -1829,7 +1912,10 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
   CHECK(old_to_new_instrs_.contains(kernel_old));
   auto kernel_new = old_to_new_instrs_[kernel_old];
 
+  auto permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
+
   HloInstruction* activations_new = nullptr;
+  bool activations_locally_space_to_batched = false;
   // If activations were no space-to-batched, we space-to-batch them below.
   if (!old_to_new_instrs_.contains(activations_old)) {
     VLOG(1) << "Space-to-batching activations to enable space-to-depth";
@@ -1838,28 +1924,34 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
     instr_to_dim_map_[activations_old] =
         std::make_pair(prev_feature_dim, prev_batch_dim);
 
-    int64 activations_batch_dim = original_conv_dims.input_feature_dimension();
-    const int64 old_batch_size =
-        activations_old->shape().dimensions(activations_batch_dim);
-    const int64 num_splits = kNewBatchSize / old_batch_size;
+    const int64 new_kernel_space_dim =
+        DimLookUp(permute_dims_kernel, kernel_space_dim);
+
     const int64 new_kernel_split_dim_size =
-        kernel_new->shape().dimensions(kernel_space_dim);
+        kernel_new->shape().dimensions(new_kernel_space_dim);
     const int64 needed_spatial_size = rhs_dilation * new_kernel_split_dim_size;
     const int64 pad_size =
-        needed_spatial_size * num_splits - old_split_dim_size;
+        needed_spatial_size * kNumSplits - old_split_dim_size;
     ConvolutionDimensionNumbers tmp_dim_numbers;
     tmp_dim_numbers = original_conv_dims;
     TF_ASSIGN_OR_RETURN(
         auto retval,
         SplitSpace(activations_old, tmp_dim_numbers, old_space_dim,
-                   activations_batch_dim,
+                   old_batch_dim,
                    /*high_padding=*/pad_size, /*low_padding=*/0,
-                   needed_spatial_size, num_splits, /*is_backprop=*/true));
+                   needed_spatial_size, kNumSplits, /*is_backprop=*/true));
 
     old_to_new_instrs_[activations_old] = retval.first;
-    instr_to_dim_permute_map_[retval.first] = retval.second;
 
-    VLOG(3) << "Edited conv dims " << original_conv_dims.DebugString();
+    std::vector<int64> reversed_transpose_dims(retval.second.size());
+    for (int64 i = 0; i < retval.second.size(); ++i) {
+      reversed_transpose_dims[i] = ReverseDimLookUp(retval.second, i);
+    }
+    instr_to_dim_permute_map_[retval.first] = reversed_transpose_dims;
+
+    VLOG(3) << "New Activations " << retval.first->ToString();
+
+    activations_locally_space_to_batched = true;
   }
 
   CHECK(old_to_new_instrs_.contains(activations_old));
@@ -1884,7 +1976,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
         i, DimLookUp(permute_dims,
                      original_conv_dims.input_spatial_dimensions(i)));
     permuted_conv_dims_numbers.set_kernel_spatial_dimensions(
-        i, DimLookUp(permute_dims,
+        i, DimLookUp(permute_dims_kernel,
                      original_conv_dims.kernel_spatial_dimensions(i)));
   }
 
@@ -1905,10 +1997,11 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
       previous_spatial_dim_count, previous_chosen_spatial_dim_in_output);
 
   const int64 kernel_input_feature_dim = DimLookUp(
-      permute_dims, original_conv_dims.kernel_input_feature_dimension());
+      permute_dims_kernel, original_conv_dims.kernel_input_feature_dimension());
 
-  const int64 kernel_output_feature_dim = DimLookUp(
-      permute_dims, original_conv_dims.kernel_output_feature_dimension());
+  const int64 kernel_output_feature_dim =
+      DimLookUp(permute_dims_kernel,
+                original_conv_dims.kernel_output_feature_dimension());
 
   permuted_conv_dims_numbers.set_kernel_input_feature_dimension(
       kernel_input_feature_dim);
@@ -1931,7 +2024,8 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
 
   VLOG(1) << "Propagating on conv activations_batch_dim "
           << activations_batch_dim << " spatial_dimension_to_split "
-          << spatial_dimension_to_split << " old_batch_size " << old_batch_size;
+          << spatial_dimension_to_split << " old_batch_size " << old_batch_size
+          << " new_split_dim_size " << new_split_dim_size;
 
   TF_ASSIGN_OR_RETURN(
       auto retval,
@@ -1939,6 +2033,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
                           spatial_dimension_to_split, activations_batch_dim,
                           /*is_backprop=*/true));
   std::vector<int64> transpose_dims = retval.transpose_dims;
+  CHECK(!transpose_dims.empty());
   activations_new = retval.instr;
 
   VLOG(1) << "Activations_new post BringSpaceNextToBatch "
@@ -1949,13 +2044,15 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
   auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
       LiteralUtil::Zero(activations_new->shape().element_type())));
 
-  // Select activations correctly by masking additional space.
-  TF_ASSIGN_OR_RETURN(
-      activations_new,
-      SelectValidPortion(activations_new, activations_old, select_val,
-                         activations_batch_dim, spatial_dimension_to_split,
-                         old_batch_dim, old_space_dim));
+  if (!activations_locally_space_to_batched) {
+    // Select activations correctly by masking additional space.
+    TF_ASSIGN_OR_RETURN(
+        activations_new,
+        SelectValidPortion(activations_new, activations_old, select_val,
+                           activations_batch_dim, spatial_dimension_to_split,
+                           old_batch_dim, old_space_dim));
+  }
+  VLOG(3) << "Selecting the valid kernel area";
   // Select kernel correctly by masking additional space.
   TF_ASSIGN_OR_RETURN(
       kernel_new,
@@ -2238,7 +2335,6 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
 
   VLOG(1) << "spatial size " << c.spatial_size;
 
-  const int64 num_splits = kNewBatchSize / old_batch_size;
   auto original_conv = convolution;
 
   const int64 output_spatial_dim = dim_numbers.output_spatial_dimensions(
@@ -2246,13 +2342,13 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
   const int64 output_offsets =
       convolution->shape().dimensions(output_spatial_dim);
   const int64 output_offsets_per_split =
-      CeilOfRatio(output_offsets, num_splits);
+      CeilOfRatio(output_offsets, kNumSplits);
 
   int64 spatial_split_size =
       CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;
   // Keep increasing the split size so that overall size isn't smaller than the
   // original spatial dimension.
-  while (spatial_split_size * num_splits - c.spatial_size < 0) {
+  while (spatial_split_size * kNumSplits - c.spatial_size < 0) {
     spatial_split_size += c.stride;
   }
 
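The hunks above replace the per-convolution num_splits with the fixed kNumSplits when sizing the spatial split. A sketch of that arithmetic with made-up sizes (CeilOfRatio is reimplemented here only for the example):

// Sketch with made-up sizes: how spatial_split_size is chosen once the split
// count is the fixed kNumSplits.
#include <cstdint>
#include <iostream>

int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t kNumSplits = 8;       // assumed constant used by the pass
  const int64_t output_offsets = 28;  // output spatial size (example)
  const int64_t base_dilation = 1, stride = 2, spatial_size = 65;

  const int64_t output_offsets_per_split =
      CeilOfRatio(output_offsets, kNumSplits);                        // 4
  int64_t spatial_split_size =
      CeilOfRatio(output_offsets_per_split, base_dilation) * stride;  // 8
  // Grow until the pieces cover the original spatial dimension.
  while (spatial_split_size * kNumSplits - spatial_size < 0) {
    spatial_split_size += stride;  // 8 -> 10: 8 * 8 = 64 < 65, 10 * 8 >= 65
  }
  std::cout << "spatial_split_size=" << spatial_split_size << "\n";  // 10
}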
@@ -2276,12 +2372,12 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
   const int64 slice_size = spatial_split_size + c.halo_size;
 
   // Pad spatial dim.
-  const int64 pad_size = spatial_split_size * num_splits - c.spatial_size;
+  const int64 pad_size = spatial_split_size * kNumSplits - c.spatial_size;
 
   VLOG(1) << "spatial_split_size " << spatial_split_size << " stride "
           << c.stride << " slice_size " << slice_size;
   VLOG(1) << "spatial_dimension_to_split " << c.spatial_dimension_to_split
-          << " num_splits " << num_splits << " kernel_spatial_dim_size "
+          << " num_splits " << kNumSplits << " kernel_spatial_dim_size "
           << c.kernel_spatial_dim_size;
   int64 spatial_dimension_to_split = c.spatial_dimension_to_split;
   TF_ASSIGN_OR_RETURN(
@@ -2292,7 +2388,7 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
           /*low_padding=*/c.base_dilation_factor == 1
               ? c.inherent_low_padding
               : 0,
-          spatial_split_size, num_splits));
+          spatial_split_size, kNumSplits));
   HloInstruction* batch_increased_reshape = retval.first;
   convolution->SetupDerivedInstruction(batch_increased_reshape);
 
@@ -317,5 +317,27 @@ TEST_F(DynamismInferenceTest, GatherWithSharedConstantParent) {
   }
 }
 
+TEST_F(DynamismInferenceTest, InferThroughPad) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    // Test the analysis on a gather.
+    auto operand1 = ConstantR1<int32>(&b, {1, 2});
+    auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
+    PaddingConfig padding_config;
+    padding_config.add_dimensions()->set_edge_padding_high(1);
+    // After pad the value is [constant, constant, parameter].
+    auto pad = Pad(operand1, parameter, padding_config);
+    ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
+    // Everything is constant, result is also contant.
+    EXPECT_FALSE(
+        ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({0}));
+    EXPECT_FALSE(
+        ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({1}));
+    EXPECT_TRUE(
+        ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({2}));
+  }
+}
+
 } // namespace
 } // namespace xla
@@ -57,7 +57,8 @@ class GpuDummyCompiler : public GpuCompiler {
 
   StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
-      GpuVersion gpu_version, se::StreamExecutor* stream_exec) {
+      GpuVersion gpu_version, se::StreamExecutor* stream_exec,
+      bool relocatable) {
     if (user_post_optimization_hook_) {
      user_post_optimization_hook_(*llvm_module);
     }
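The hunk above extends CompileTargetBinary with a relocatable flag, so every override must pick up the extra parameter even if it ignores it. The snippet below is a stripped-down analogue, not the real GpuCompiler interface, just to show the shape of that change:

// Stripped-down analogue (not the real GpuCompiler API): what adding the
// `bool relocatable` parameter means for subclasses.
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct FakeCompilerBase {
  virtual ~FakeCompilerBase() = default;
  // New-style signature: the final flag says whether to emit relocatable code.
  virtual std::pair<std::string, std::vector<uint8_t>> CompileTargetBinary(
      const std::string& module_name, bool relocatable) = 0;
};

struct FakeDummyCompiler : FakeCompilerBase {
  std::pair<std::string, std::vector<uint8_t>> CompileTargetBinary(
      const std::string& module_name, bool relocatable) override {
    (void)relocatable;  // like the dummy compiler above, ignore the new flag
    return {module_name + ".bin", {}};
  }
};

int main() {
  FakeDummyCompiler c;
  std::cout << c.CompileTargetBinary("m", /*relocatable=*/false).first << "\n";
}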
@@ -20,9 +20,8 @@ op {
       "A `Tensor` of type T. An alias of `x`. The content "
       "of `y` is undefined if there are duplicates in `i`."
   }
-  summary: <<END
-Adds v into specified rows of x.
-
+  summary: "Adds v into specified rows of x."
+  description: <<END
 Computes y = x; y[i, :] += v; return y.
 END
 }
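The reworded api_def above keeps the semantics "Computes y = x; y[i, :] += v; return y." A standalone sketch of that formula on plain vectors (shapes and values invented for the example; behavior with duplicate indices is undefined, as the doc notes):

// Sketch only: "y = x; y[i, :] += v; return y" on plain vectors.
#include <cstdint>
#include <iostream>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

Matrix InplaceAdd(Matrix x, const std::vector<int64_t>& i, const Matrix& v) {
  Matrix y = std::move(x);  // y = x
  for (size_t r = 0; r < i.size(); ++r) {
    for (size_t c = 0; c < y[i[r]].size(); ++c) {
      y[i[r]][c] += v[r][c];  // y[i[r], :] += v[r, :]
    }
  }
  return y;
}

int main() {
  Matrix x = {{1, 1}, {2, 2}, {3, 3}};
  Matrix y = InplaceAdd(x, /*i=*/{2, 0}, /*v=*/{{10, 10}, {20, 20}});
  for (const auto& row : y) std::cout << row[0] << ' ' << row[1] << '\n';
  // Prints: 21 21 / 2 2 / 13 13
}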
@@ -1,8 +1,8 @@
 op {
   graph_op_name: "TopKUnique"
-  summary: "Returns the TopK unique values in the array in sorted order. The"
+  summary: "Returns the TopK unique values in the array in sorted order."
   description: <<END
-running time is proportional to the product of K and the input
+The running time is proportional to the product of K and the input
 size. Sorting the whole array is more efficient for sufficiently large
 values of K. The median-of-medians algorithm is probably faster, but
 difficult to implement efficiently in XLA. If there are fewer than K
@@ -1,10 +1,11 @@
 op {
   graph_op_name: "TopKWithUnique"
-  summary: "Returns the TopK values in the array in sorted order. This is a combination"
+  summary: "Returns the TopK values in the array in sorted order."
   description: <<END
-of MakeUnique and TopKUnique. The returned top-K will have its lower bits
-replaced by iota, thus it will be close to the original value but not exactly
-the same. The running time is proportional to the product of K and the input
-size. NaNs are never returned. Subnormal numbers are flushed to zero.
+This is a combination of MakeUnique and TopKUnique. The returned top-K will
+have its lower bits replaced by iota, thus it will be close to the original
+value but not exactly the same. The running time is proportional to the product
+of K and the input size. NaNs are never returned. Subnormal numbers are flushed
+to zero.
 END
 }
@@ -705,10 +705,10 @@ Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers(
   return Status::OK();
 }
 
-Status EagerContext::AddFunctionDefWithDebugInfo(
-    const FunctionDef& fdef, const Graph* graph_with_debug_info) {
+Status EagerContext::AddFunctionDefWithStackTraces(
+    const FunctionDef& fdef, const StackTracesMap& stack_traces) {
   return AddFunctionDef(fdef, FunctionDefLibrary(),
-                        /* add_to_local_only=*/false, graph_with_debug_info);
+                        /* add_to_local_only=*/false, stack_traces);
 }
 
 Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
@@ -719,7 +719,7 @@ Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
 Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
                                     const FunctionDefLibrary& library,
                                     const bool add_to_local_only,
-                                    const Graph* graph_with_debug_info) {
+                                    const StackTracesMap& stack_traces) {
   bool is_first_ref = false;
   {
     mutex_lock l(cache_mu_);
@@ -753,8 +753,7 @@ Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
     is_first_ref = registered_function->RefCountIsOne();
   }
   if (is_first_ref) {
-    TF_RETURN_IF_ERROR(
-        func_lib_def_.AddFunctionDef(fdef, graph_with_debug_info));
+    TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef, stack_traces));
     TF_RETURN_IF_ERROR(func_lib_def_.AddLibrary(library));
     if (!add_to_local_only) {
       return MaybeRegisterFunctionRemotely(fdef);
Some files were not shown because too many files have changed in this diff.