Merge remote-tracking branch 'upstream/master'

This commit is contained in:
Avijit 2018-08-12 16:21:41 -07:00
commit 9523a98466
2065 changed files with 79800 additions and 17381 deletions

View File

@ -22,6 +22,8 @@ organization for the purposes of conducting machine learning and deep neural
networks research. The system is general enough to be applicable in a wide networks research. The system is general enough to be applicable in a wide
variety of other domains, as well. variety of other domains, as well.
TensorFlow provides stable Python API and C APIs as well as without API backwards compatibility guarantee like C++, Go, Java, JavaScript and Swift.
Keep up to date with release announcements and security updates by Keep up to date with release announcements and security updates by
subscribing to subscribing to
[announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce). [announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce).
@ -81,13 +83,13 @@ The TensorFlow project strives to abide by generally accepted best practices in
| Build Type | Status | Artifacts | | Build Type | Status | Artifacts |
| --- | --- | --- | | --- | --- | --- |
| **Linux CPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) | | **Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [pypi](https://pypi.org/project/tf-nightly/) |
| **Linux GPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.png) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | | **Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
| **Linux XLA** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.png) | TBA | | **Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA |
| **MacOS** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) | | **MacOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [pypi](https://pypi.org/project/tf-nightly/) |
| **Windows CPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.png) | [pypi](https://pypi.org/project/tf-nightly/) | | **Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [pypi](https://pypi.org/project/tf-nightly/) |
| **Windows GPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.png) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | | **Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
| **Android** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.png) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) | | **Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) |
### Community Supported Builds ### Community Supported Builds
@ -97,17 +99,20 @@ The TensorFlow project strives to abide by generally accepted best practices in
| **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA | | **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA |
| **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA | | **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA |
| **IBM ppc64le GPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA | | **IBM ppc64le GPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA |
| **Linux CPU with Intel® MKL-DNN®** | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | TBA | | **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) |
| **Linux CPU with Intel® MKL-DNN** Python 2.7<br> **Linux CPU with Intel® MKL-DNN** Python 3.5<br> **Linux CPU with Intel® MKL-DNN** Python 3.6| ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)|[1.9.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp27-cp27mu-linux_x86_64.whl)<br>[1.9.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp35-cp35m-linux_x86_64.whl)<br>[1.9.0 py3.6](https://storage.cloud.google.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp36-cp36m-linux_x86_64.whl) |
## For more information ## For more information
* [Tensorflow Blog](https://medium.com/tensorflow)
* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si)
* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
* [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730)
* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
* [Tensorflow Twitter](https://twitter.com/tensorflow)
* [TensorFlow Website](https://www.tensorflow.org) * [TensorFlow Website](https://www.tensorflow.org)
* [TensorFlow White Papers](https://www.tensorflow.org/about/bib) * [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) * [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
* [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730)
* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si)
Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate. Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate.

View File

@ -19,7 +19,7 @@
* `tf.data`: * `tf.data`:
* `tf.contrib.data.group_by_reducer()` is now available via the public API. * `tf.contrib.data.group_by_reducer()` is now available via the public API.
* `tf.contrib.data.choose_from_datasets()` is now available via the public API. * `tf.contrib.data.choose_from_datasets()` is now available via the public API.
* Adding `drop_remainder` argument to `tf.data.Dataset.batch()` and `tf.data.Dataset.padded_batch()`, deprecating tf.contrib.data.batch_and_drop_remainder()` and `tf.contrib.data.padded_batch_and_drop_remainder()`. * Adding `drop_remainder` argument to `tf.data.Dataset.batch()` and `tf.data.Dataset.padded_batch()`, deprecating `tf.contrib.data.batch_and_drop_remainder()` and `tf.contrib.data.padded_batch_and_drop_remainder()`.
* `tf.estimator`: * `tf.estimator`:
* `Estimator`s now use custom savers included in `EstimatorSpec` scaffolds for saving SavedModels during export. * `Estimator`s now use custom savers included in `EstimatorSpec` scaffolds for saving SavedModels during export.
* `EstimatorSpec` will now add a default prediction output for export if no `export_output` is provided, eliminating the need to explicitly include a `PredictOutput` object in the `model_fn` for simple use-cases. * `EstimatorSpec` will now add a default prediction output for export if no `export_output` is provided, eliminating the need to explicitly include a `PredictOutput` object in the `model_fn` for simple use-cases.

View File

@ -451,11 +451,6 @@ filegroup(
), ),
) )
filegroup(
name = "docs_src",
data = glob(["docs_src/**/*.md"]),
)
cc_library( cc_library(
name = "grpc", name = "grpc",
deps = select({ deps = select({
@ -599,6 +594,7 @@ exports_files(
gen_api_init_files( gen_api_init_files(
name = "tensorflow_python_api_gen", name = "tensorflow_python_api_gen",
srcs = ["api_template.__init__.py"], srcs = ["api_template.__init__.py"],
api_version = 1,
root_init_template = "api_template.__init__.py", root_init_template = "api_template.__init__.py",
) )

View File

@ -1619,5 +1619,66 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) {
TF_DeleteFunction(func1); TF_DeleteFunction(func1);
} }
// This test only works when the TF build includes XLA compiler. One way to set
// this up is via bazel build option "--define with_xla_support=true".
//
// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to
// something like TENSORFLOW_CAPI_USE_XLA.
#ifdef TENSORFLOW_EAGER_USE_XLA
TEST_F(CApiFunctionTest, StatelessIf_XLA) {
TF_Function* func;
const std::string funcName = "BranchFunc";
DefineFunction(funcName.c_str(), &func);
TF_GraphCopyFunction(host_graph_, func, nullptr, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Operation* feed = Placeholder(host_graph_, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Operation* true_cond = ScalarConst(true, host_graph_, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_OperationDescription* desc =
TF_NewOperation(host_graph_, "StatelessIf", "IfNode");
TF_AddInput(desc, {true_cond, 0});
TF_Output inputs[] = {{feed, 0}};
TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs));
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_SetAttrType(desc, "Tcond", TF_BOOL);
TF_DataType inputType = TF_INT32;
TF_SetAttrTypeList(desc, "Tin", &inputType, 1);
TF_SetAttrTypeList(desc, "Tout", &inputType, 1);
TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size());
TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size());
TF_SetDevice(desc, "/device:XLA_CPU:0");
auto op = TF_FinishOperation(desc, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
ASSERT_NE(op, nullptr);
// Create a session for this graph.
CSession csession(host_graph_, s_, /*use_XLA*/ true);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
// Run the graph.
csession.SetInputs({{feed, Int32Tensor(17)}});
csession.SetOutputs({op});
csession.Run(s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Tensor* out = csession.output_tensor(0);
ASSERT_TRUE(out != nullptr);
EXPECT_EQ(TF_INT32, TF_TensorType(out));
EXPECT_EQ(0, TF_NumDims(out)); // scalar
ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
int32* output_contents = static_cast<int32*>(TF_TensorData(out));
EXPECT_EQ(-17, *output_contents);
// Clean up
csession.CloseAndDelete(s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_DeleteFunction(func);
}
#endif // TENSORFLOW_EAGER_USE_XLA
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -26,6 +26,10 @@ limitations under the License.
using tensorflow::GraphDef; using tensorflow::GraphDef;
using tensorflow::NodeDef; using tensorflow::NodeDef;
static void BoolDeallocator(void* data, size_t, void* arg) {
delete[] static_cast<bool*>(data);
}
static void Int32Deallocator(void* data, size_t, void* arg) { static void Int32Deallocator(void* data, size_t, void* arg) {
delete[] static_cast<int32_t*>(data); delete[] static_cast<int32_t*>(data);
} }
@ -38,6 +42,14 @@ static void FloatDeallocator(void* data, size_t, void* arg) {
delete[] static_cast<float*>(data); delete[] static_cast<float*>(data);
} }
TF_Tensor* BoolTensor(bool v) {
const int num_bytes = sizeof(bool);
bool* values = new bool[1];
values[0] = v;
return TF_NewTensor(TF_BOOL, nullptr, 0, values, num_bytes, &BoolDeallocator,
nullptr);
}
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
int64_t num_values = 1; int64_t num_values = 1;
for (int i = 0; i < num_dims; ++i) { for (int i = 0; i < num_dims; ++i) {
@ -131,6 +143,12 @@ TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
return op; return op;
} }
TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
const char* name) {
unique_tensor_ptr tensor(BoolTensor(v), TF_DeleteTensor);
return Const(tensor.get(), graph, s, name);
}
TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
const char* name) { const char* name) {
unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor); unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor);

View File

@ -31,6 +31,8 @@ using ::tensorflow::string;
typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)>
unique_tensor_ptr; unique_tensor_ptr;
TF_Tensor* BoolTensor(int32_t v);
// Create a tensor with values of type TF_INT8 provided by `values`. // Create a tensor with values of type TF_INT8 provided by `values`.
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values); TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values);
@ -55,6 +57,9 @@ TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
const char* name = "const"); const char* name = "const");
TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
const char* name = "scalar");
TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
const char* name = "scalar"); const char* name = "scalar");

View File

@ -110,7 +110,7 @@ tensorflow::Status GetAllRemoteDevices(
tensorflow::Status CreateRemoteContexts( tensorflow::Status CreateRemoteContexts(
const std::vector<string>& remote_workers, int64 rendezvous_id, const std::vector<string>& remote_workers, int64 rendezvous_id,
const tensorflow::ServerDef& server_def, int keep_alive_secs, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) { tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) {
for (int i = 0; i < remote_workers.size(); i++) { for (int i = 0; i < remote_workers.size(); i++) {
@ -129,6 +129,7 @@ tensorflow::Status CreateRemoteContexts(
request.mutable_server_def()->set_job_name(parsed_name.job); request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task); request.mutable_server_def()->set_task_index(parsed_name.task);
request.set_async(async); request.set_async(async);
request.set_keep_alive_secs(keep_alive_secs);
auto* eager_client = remote_eager_workers->GetClient(remote_worker); auto* eager_client = remote_eager_workers->GetClient(remote_worker);
if (eager_client == nullptr) { if (eager_client == nullptr) {
return tensorflow::errors::Internal( return tensorflow::errors::Internal(
@ -151,7 +152,8 @@ tensorflow::Status CreateRemoteContexts(
} }
tensorflow::Status UpdateTFE_ContextWithServerDef( tensorflow::Status UpdateTFE_ContextWithServerDef(
const tensorflow::ServerDef& server_def, TFE_Context* ctx) { int keep_alive_secs, const tensorflow::ServerDef& server_def,
TFE_Context* ctx) {
// We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
// server object (which currently CHECK-fails) and we miss the error, instead, // server object (which currently CHECK-fails) and we miss the error, instead,
// we log the error, and then return to allow the user to see the error // we log the error, and then return to allow the user to see the error
@ -202,8 +204,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
// Initialize remote eager workers. // Initialize remote eager workers.
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts; tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
remote_workers, rendezvous_id, server_def, remote_eager_workers.get(), remote_workers, rendezvous_id, keep_alive_secs, server_def,
ctx->context.Async(), &remote_contexts)); remote_eager_workers.get(), ctx->context.Async(), &remote_contexts));
tensorflow::RemoteRendezvous* r = tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id); grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
@ -222,9 +224,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
auto* device_mgr = grpc_server->worker_env()->device_mgr; auto* device_mgr = grpc_server->worker_env()->device_mgr;
ctx->context.InitializeRemote( ctx->context.InitializeRemote(std::move(server),
std::move(server), std::move(remote_eager_workers), std::move(remote_eager_workers),
std::move(remote_device_mgr), remote_contexts, r, device_mgr); std::move(remote_device_mgr), remote_contexts,
r, device_mgr, keep_alive_secs);
return tensorflow::Status::OK(); return tensorflow::Status::OK();
#undef LOG_AND_RETURN_IF_ERROR #undef LOG_AND_RETURN_IF_ERROR
@ -288,6 +291,7 @@ void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); }
// Set server_def on the context, possibly updating it. // Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
int keep_alive_secs,
const void* proto, const void* proto,
size_t proto_len, size_t proto_len,
TF_Status* status) { TF_Status* status) {
@ -297,7 +301,8 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
"Invalid tensorflow.ServerDef protocol buffer"); "Invalid tensorflow.ServerDef protocol buffer");
return; return;
} }
status->status = UpdateTFE_ContextWithServerDef(server_def, ctx); status->status =
UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx);
} }
void TFE_ContextSetThreadLocalDevicePlacementPolicy( void TFE_ContextSetThreadLocalDevicePlacementPolicy(
@ -719,6 +724,10 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
} }
} // namespace } // namespace
void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); }
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); }
namespace tensorflow { namespace tensorflow {
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
const tensorflow::AttrValue& default_value, const tensorflow::AttrValue& default_value,

View File

@ -124,6 +124,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*,
// If the following is set, all servers identified by the // If the following is set, all servers identified by the
// ServerDef must be up when the context is created. // ServerDef must be up when the context is created.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
int keep_alive_secs,
const void* proto, const void* proto,
size_t proto_len, size_t proto_len,
TF_Status* status); TF_Status* status);
@ -380,6 +381,16 @@ TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx,
TF_Buffer* buf, TF_Buffer* buf,
TF_Status* status); TF_Status* status);
// Some TF ops need a step container to be set to limit the lifetime of some
// resources (mostly TensorArray and Stack, used in while loop gradients in
// graph mode). Calling this on a context tells it to start a step.
TF_CAPI_EXPORT extern void TFE_ContextStartStep(TFE_Context* ctx);
// Ends a step. When there is no active step (that is, every started step has
// been ended) step containers will be cleared. Note: it is not safe to call
// TFE_ContextEndStep while ops which rely on the step container may be running.
TF_CAPI_EXPORT extern void TFE_ContextEndStep(TFE_Context* ctx);
#ifdef __cplusplus #ifdef __cplusplus
} /* end extern "C" */ } /* end extern "C" */
#endif #endif

View File

@ -151,7 +151,7 @@ void TestRemoteExecute(bool async) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts); TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status); TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
@ -239,7 +239,7 @@ void TestRemoteExecuteSilentCopies(bool async) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts); TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status); TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
@ -371,7 +371,7 @@ void TestRemoteExecuteChangeServerDef(bool async) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts); TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status); TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char remote_device_name[] = const char remote_device_name[] =
@ -397,7 +397,7 @@ void TestRemoteExecuteChangeServerDef(bool async) {
ASSERT_TRUE(s.ok()) << s.error_message(); ASSERT_TRUE(s.ok()) << s.error_message();
ASSERT_TRUE(worker_server->Start().ok()); ASSERT_TRUE(worker_server->Start().ok());
TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status); TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Create a new tensor_handle. // Create a new tensor_handle.

View File

@ -379,9 +379,11 @@ tf_cc_test(
srcs = ["gradients/math_grad_test.cc"], srcs = ["gradients/math_grad_test.cc"],
deps = [ deps = [
":cc_ops", ":cc_ops",
":client_session",
":grad_op_registry", ":grad_op_registry",
":grad_testutil", ":grad_testutil",
":gradient_checker", ":gradient_checker",
":gradients",
":math_grad", ":math_grad",
":testutil", ":testutil",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",

View File

@ -120,6 +120,24 @@ Status SplitGrad(const Scope& scope, const Operation& op,
} }
REGISTER_GRADIENT_OP("Split", SplitGrad); REGISTER_GRADIENT_OP("Split", SplitGrad);
Status FillGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
// y = fill(fill_shape, x)
// No gradient returned for the fill_shape argument.
grad_outputs->push_back(NoGradient());
// The gradient for x (which must be a scalar) is just the sum of
// all the gradients from the shape it fills.
// We use ReduceSum to implement this, which needs an argument providing
// the indices of all the dimensions of the incoming gradient.
// grad(x) = reduce_sum(grad(y), [0..rank(grad(y))])
auto all_dims = Range(scope, Const(scope, 0), Rank(scope, grad_inputs[0]),
Const(scope, 1));
grad_outputs->push_back(ReduceSum(scope, grad_inputs[0], all_dims));
return scope.status();
}
REGISTER_GRADIENT_OP("Fill", FillGrad);
Status DiagGrad(const Scope& scope, const Operation& op, Status DiagGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs, const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) { std::vector<Output>* grad_outputs) {

View File

@ -108,6 +108,14 @@ TEST_F(ArrayGradTest, SplitGrad) {
RunTest({x}, {x_shape}, y.output, {y_shape, y_shape}); RunTest({x}, {x_shape}, y.output, {y_shape, y_shape});
} }
TEST_F(ArrayGradTest, FillGrad) {
TensorShape x_shape({});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
TensorShape y_shape({2, 5, 3});
auto y = Fill(scope_, {2, 5, 3}, x);
RunTest(x, x_shape, y, y_shape);
}
TEST_F(ArrayGradTest, DiagGrad) { TEST_F(ArrayGradTest, DiagGrad) {
TensorShape x_shape({5, 2}); TensorShape x_shape({5, 2});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));

View File

@ -441,6 +441,22 @@ Status RealDivGrad(const Scope& scope, const Operation& op,
} }
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad); REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);
Status UnsafeDivGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
auto x_1 = ConjugateHelper(scope, op.input(0));
auto x_2 = ConjugateHelper(scope, op.input(1));
// y = x_1 / x_2
// dy/dx_1 = 1/x_2
// dy/dx_2 = -x_1/x_2^2
auto gx_1 = UnsafeDiv(scope, grad_inputs[0], x_2);
auto gx_2 =
Mul(scope, grad_inputs[0],
UnsafeDiv(scope, UnsafeDiv(scope, Neg(scope, x_1), x_2), x_2));
return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("UnsafeDiv", UnsafeDivGrad);
Status SquaredDifferenceGrad(const Scope& scope, const Operation& op, Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs, const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) { std::vector<Output>* grad_outputs) {
@ -1007,6 +1023,26 @@ Status ProdGrad(const Scope& scope, const Operation& op,
} }
REGISTER_GRADIENT_OP("Prod", ProdGrad); REGISTER_GRADIENT_OP("Prod", ProdGrad);
Status SegmentSumGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
// The SegmentSum operation sums segments of the Tensor that have the same
// index in the segment_ids parameter.
// i.e z = [2, 3, 4, 5], segment_ids [0, 0, 0, 1]
// will produce [2 + 3 + 4, 5] = [9, 5]
// The gradient that will flow back to the gather operation will look like
// [x1, x2], it will have the same shape as the output of the SegmentSum
// operation. The differentiation step of the SegmentSum operation just
// broadcast the gradient in order to retrieve the z's shape.
// dy/dz = [x1, x1, x1, x2]
grad_outputs->push_back(Gather(scope, grad_inputs[0], op.input(1)));
// stop propagation along segment_ids
grad_outputs->push_back(NoGradient());
return scope.status();
}
REGISTER_GRADIENT_OP("SegmentSum", SegmentSumGrad);
// MatMulGrad helper function used to compute two MatMul operations // MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations. // based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch, Status MatMulGradHelper(const Scope& scope, const bool is_batch,

View File

@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/grad_op_registry.h" #include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradient_checker.h" #include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h" #include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/cc/ops/standard_ops.h"
@ -42,9 +44,11 @@ using ops::Placeholder;
using ops::Pow; using ops::Pow;
using ops::Prod; using ops::Prod;
using ops::RealDiv; using ops::RealDiv;
using ops::SegmentSum;
using ops::SquaredDifference; using ops::SquaredDifference;
using ops::Sub; using ops::Sub;
using ops::Sum; using ops::Sum;
using ops::UnsafeDiv;
// TODO(andydavis) Test gradient function against numeric gradients output. // TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions // TODO(andydavis) As more gradients are added move common test functions
@ -850,6 +854,36 @@ TEST_F(NaryGradTest, RealDiv) {
RunTest({x}, {x_shape}, {y}, {x_shape}); RunTest({x}, {x_shape}, {y}, {x_shape});
} }
TEST_F(NaryGradTest, UnsafeDiv) {
{
TensorShape x_shape({3, 2, 5});
const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
// Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
// division errors in the numeric estimator used by the gradient checker.
const auto y = UnsafeDiv(
scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
RunTest({x}, {x_shape}, {y}, {x_shape});
}
{
// Return 0 gradient (rather than NaN) for division by zero.
const auto x = Placeholder(scope_, DT_FLOAT);
const auto zero = Const<float>(scope_, 0.0);
const auto y = UnsafeDiv(scope_, x, zero);
std::vector<Output> grad_outputs;
TF_EXPECT_OK(AddSymbolicGradients(scope_, {y}, {x}, &grad_outputs));
ClientSession session(scope_);
std::vector<Tensor> grad_result;
TF_EXPECT_OK(
session.Run({{x, {-3.0f, 0.0f, 3.0f}}}, grad_outputs, &grad_result));
EXPECT_EQ(grad_result.size(), 1);
EXPECT_EQ(grad_result[0].NumElements(), 3);
EXPECT_EQ(grad_result[0].flat<float>()(0), 0.0f);
EXPECT_EQ(grad_result[0].flat<float>()(1), 0.0f);
EXPECT_EQ(grad_result[0].flat<float>()(2), 0.0f);
}
}
TEST_F(NaryGradTest, SquaredDifference) { TEST_F(NaryGradTest, SquaredDifference) {
TensorShape x1_shape({3, 2, 5}); TensorShape x1_shape({3, 2, 5});
TensorShape x2_shape({2, 5}); TensorShape x2_shape({2, 5});
@ -898,5 +932,14 @@ TEST_F(NaryGradTest, Prod) {
RunTest({x}, {x_shape}, {y}, {y_shape}); RunTest({x}, {x_shape}, {y}, {y_shape});
} }
TEST_F(NaryGradTest, SegmentSum) {
TensorShape x_shape({3, 4});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
auto y = SegmentSum(scope_, x, {0, 0, 1});
// the sum is always on the first dimension
TensorShape y_shape({2, 4});
RunTest({x}, {x_shape}, {y}, {y_shape});
}
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -170,7 +170,8 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
variables_directory, MetaFilename(kSavedModelVariablesFilename)); variables_directory, MetaFilename(kSavedModelVariablesFilename));
if (!Env::Default()->FileExists(variables_index_path).ok()) { if (!Env::Default()->FileExists(variables_index_path).ok()) {
LOG(INFO) << "The specified SavedModel has no variables; no checkpoints " LOG(INFO) << "The specified SavedModel has no variables; no checkpoints "
"were restored."; "were restored. File does not exist: "
<< variables_index_path;
return Status::OK(); return Status::OK();
} }
const string variables_path = const string variables_path =

View File

@ -48,6 +48,7 @@ cc_library(
"//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/compiler/xla/client:compile_only_client",
"//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service/cpu:buffer_info_util",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/str_util.h" #include "tensorflow/compiler/tf2xla/str_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
@ -36,6 +37,8 @@ namespace tfcompile {
namespace { namespace {
using BufferInfo = cpu_function_runtime::BufferInfo;
bool IsAlpha(char c) { bool IsAlpha(char c) {
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z'); return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
} }
@ -85,27 +88,36 @@ Status XLATypeToCpp(xla::PrimitiveType type, string* str) {
return Status::OK(); return Status::OK();
} }
// total_buffer_bytes returns the sum of each size in `sizes`, skipping -1 // Returns the sum of the size of each buffer in `buffer_infos`.
// values. There are `n` entries in `sizes`. size_t TotalBufferBytes(const std::vector<BufferInfo>& buffer_infos) {
size_t total_buffer_bytes(const intptr_t* sizes, size_t n) { return std::accumulate(buffer_infos.begin(), buffer_infos.end(), size_t{0},
size_t total = 0; [](size_t size, const BufferInfo& buffer_info) {
for (size_t i = 0; i < n; ++i) { return size + buffer_info.size();
if (sizes[i] != -1) { });
total += sizes[i];
}
}
return total;
} }
// Fills in arg_sizes with the byte size of each positional arg. // Returns a vector of BufferInfo instances in `buffer_infos` that are entry
Status ComputeArgSizes(const CompileResult& compile_result, // parameter buffers.
std::vector<int64>* arg_sizes) { std::vector<BufferInfo> ExtractEntryParamBufferInfos(
const xla::ProgramShape& ps = compile_result.program_shape; const std::vector<BufferInfo>& buffer_infos) {
for (int i = 0; i < ps.parameters_size(); ++i) { std::vector<BufferInfo> result;
arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf( std::copy_if(buffer_infos.begin(), buffer_infos.end(),
ps.parameters(i), compile_result.pointer_size)); std::back_inserter(result), [](const BufferInfo& buffer_info) {
return buffer_info.is_entry_parameter();
});
return result;
} }
return Status::OK();
// Returns a vector of BufferInfo instances in `buffer_infos` that are temp
// buffers.
std::vector<BufferInfo> ExtractTempBufferInfos(
const std::vector<BufferInfo>& buffer_infos) {
std::vector<BufferInfo> result;
std::copy_if(buffer_infos.begin(), buffer_infos.end(),
std::back_inserter(result), [](const BufferInfo& buffer_info) {
return buffer_info.is_temp_buffer();
});
return result;
} }
// Add (from,to) rewrite pairs based on the given shape. These rewrite pairs // Add (from,to) rewrite pairs based on the given shape. These rewrite pairs
@ -278,6 +290,25 @@ Status ValidateFeedFetchCppNames(const tf2xla::Config& config) {
return Status::OK(); return Status::OK();
} }
// Returns a list of C++ expressions that, when executed, will construct the
// BufferInfo instances in `buffer_infos`.
std::vector<string> BufferInfosToCppExpression(
const std::vector<BufferInfo>& buffer_infos) {
std::vector<string> buffer_infos_as_strings;
std::transform(buffer_infos.begin(), buffer_infos.end(),
std::back_inserter(buffer_infos_as_strings),
[](const BufferInfo& buffer_info) {
std::pair<uint64, uint64> encoded = buffer_info.Encode();
string encoded_second_as_str =
encoded.second == ~0ULL
? "~0ULL"
: strings::StrCat(encoded.second, "ULL");
return strings::StrCat(
"::tensorflow::cpu_function_runtime::BufferInfo({",
encoded.first, "ULL, ", encoded_second_as_str, "})");
});
return buffer_infos_as_strings;
}
} // namespace } // namespace
Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
@ -286,29 +317,35 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
TF_RETURN_IF_ERROR(ValidateConfig(config)); TF_RETURN_IF_ERROR(ValidateConfig(config));
TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config)); TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config));
const int64 result_index = compile_result.aot->result_buffer_index(); const int64 result_index = compile_result.aot->result_buffer_index();
const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes(); const std::vector<BufferInfo>& buffer_infos =
if (result_index < 0 || result_index >= temp_sizes.size()) { compile_result.aot->buffer_infos();
const std::vector<int32> arg_index_table =
::xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos);
std::vector<string> buffer_infos_as_strings =
BufferInfosToCppExpression(buffer_infos);
if (result_index < 0 || result_index >= buffer_infos.size()) {
return errors::InvalidArgument("result index: ", result_index, return errors::InvalidArgument("result index: ", result_index,
" is outside the range of temp sizes: [0,", " is outside the range of temp sizes: [0,",
temp_sizes.size(), ")"); buffer_infos.size(), ")");
} }
// Compute sizes and generate methods. // Compute sizes and generate methods.
std::vector<int64> arg_sizes; std::vector<BufferInfo> buffer_infos_for_args =
TF_RETURN_IF_ERROR(ComputeArgSizes(compile_result, &arg_sizes)); ExtractEntryParamBufferInfos(buffer_infos);
std::vector<BufferInfo> buffer_infos_for_temps =
ExtractTempBufferInfos(buffer_infos);
const xla::ProgramShape& ps = compile_result.program_shape; const xla::ProgramShape& ps = compile_result.program_shape;
string methods_arg, methods_result; string methods_arg, methods_result;
TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg)); TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result)); TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
const std::vector<intptr_t> iarg(arg_sizes.begin(), arg_sizes.end()); const size_t arg_bytes_aligned = cpu_function_runtime::AlignedBufferBytes(
const std::vector<intptr_t> itemp(temp_sizes.begin(), temp_sizes.end()); buffer_infos_for_args.data(), buffer_infos_for_args.size(),
const size_t arg_bytes_aligned = /*allocate_entry_params=*/true);
cpu_function_runtime::AlignedBufferBytes(iarg.data(), iarg.size()); const size_t arg_bytes_total = TotalBufferBytes(buffer_infos_for_args);
const size_t arg_bytes_total = total_buffer_bytes(iarg.data(), iarg.size()); const size_t temp_bytes_aligned = cpu_function_runtime::AlignedBufferBytes(
const size_t temp_bytes_aligned = buffer_infos_for_temps.data(), buffer_infos_for_temps.size(),
cpu_function_runtime::AlignedBufferBytes(itemp.data(), itemp.size()); /*allocate_entry_params=*/true);
const size_t temp_bytes_total = const size_t temp_bytes_total = TotalBufferBytes(buffer_infos_for_temps);
total_buffer_bytes(itemp.data(), itemp.size());
// Create rewrite strings for namespace start and end. // Create rewrite strings for namespace start and end.
string ns_start; string ns_start;
@ -343,8 +380,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
// calling HloProfilePrinter::profile_counters_size. // calling HloProfilePrinter::profile_counters_size.
const string assign_profile_counters_size = const string assign_profile_counters_size =
opts.gen_hlo_profile_printer_data opts.gen_hlo_profile_printer_data
? "data->profile_counters_size = " ? "data->set_profile_counters_size("
"data->hlo_profile_printer_data->profile_counters_size();" "data->hlo_profile_printer_data()->profile_counters_size());"
: ""; : "";
// Use a poor-man's text templating mechanism; first populate the full header // Use a poor-man's text templating mechanism; first populate the full header
@ -414,9 +451,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
static constexpr size_t kNumArgs = {{ARG_NUM}}; static constexpr size_t kNumArgs = {{ARG_NUM}};
// Byte size of each argument buffer. There are kNumArgs entries. // Byte size of each argument buffer. There are kNumArgs entries.
static const intptr_t* ArgSizes() { static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
static constexpr intptr_t kArgSizes[kNumArgs] = {{{ARG_SIZES}}}; return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
return kArgSizes;
} }
// Returns static data used to create an XlaCompiledCpuFunction. // Returns static data used to create an XlaCompiledCpuFunction.
@ -424,16 +460,16 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ static XlaCompiledCpuFunction::StaticData* kStaticData = [](){
XlaCompiledCpuFunction::StaticData* data = XlaCompiledCpuFunction::StaticData* data =
new XlaCompiledCpuFunction::StaticData; new XlaCompiledCpuFunction::StaticData;
data->raw_function = {{ENTRY}}; data->set_raw_function({{ENTRY}});
data->arg_sizes = ArgSizes(); data->set_buffer_infos(BufferInfos());
data->num_args = kNumArgs; data->set_num_buffers(kNumBuffers);
data->temp_sizes = TempSizes(); data->set_arg_index_table(ArgIndexToBufferIndex());
data->num_temps = kNumTemps; data->set_num_args(kNumArgs);
data->result_index = kResultIndex; data->set_result_index(kResultIndex);
data->arg_names = StaticArgNames(); data->set_arg_names(StaticArgNames());
data->result_names = StaticResultNames(); data->set_result_names(StaticResultNames());
data->program_shape = StaticProgramShape(); data->set_program_shape(StaticProgramShape());
data->hlo_profile_printer_data = StaticHloProfilePrinterData(); data->set_hlo_profile_printer_data(StaticHloProfilePrinterData());
{{ASSIGN_PROFILE_COUNTERS_SIZE}} {{ASSIGN_PROFILE_COUNTERS_SIZE}}
return data; return data;
}(); }();
@ -482,17 +518,27 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{{METHODS_RESULT}} {{METHODS_RESULT}}
private: private:
// Number of result and temporary buffers for the compiled computation. // Number of buffers for the compiled computation.
static constexpr size_t kNumTemps = {{TEMP_NUM}}; static constexpr size_t kNumBuffers = {{NUM_BUFFERS}};
static const ::tensorflow::cpu_function_runtime::BufferInfo* BufferInfos() {
static const ::tensorflow::cpu_function_runtime::BufferInfo
kBufferInfos[kNumBuffers] = {
{{BUFFER_INFOS_AS_STRING}}
};
return kBufferInfos;
}
static const ::tensorflow::int32* ArgIndexToBufferIndex() {
static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
{{ARG_INDEX_TABLE}}
};
return kArgIndexToBufferIndex;
}
// The 0-based index of the result tuple in the temporary buffers. // The 0-based index of the result tuple in the temporary buffers.
static constexpr size_t kResultIndex = {{RESULT_INDEX}}; static constexpr size_t kResultIndex = {{RESULT_INDEX}};
// Byte size of each result / temporary buffer. There are kNumTemps entries.
static const intptr_t* TempSizes() {
static constexpr intptr_t kTempSizes[kNumTemps] = {{{TEMP_SIZES}}};
return kTempSizes;
}
// Array of names of each positional argument, terminated by nullptr. // Array of names of each positional argument, terminated by nullptr.
static const char** StaticArgNames() {{ARG_NAMES_CODE}} static const char** StaticArgNames() {{ARG_NAMES_CODE}}
@ -523,8 +569,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)}, {"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)},
{"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)}, {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)},
{"{{ARG_NAMES_CODE}}", arg_names_code}, {"{{ARG_NAMES_CODE}}", arg_names_code},
{"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())}, {"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())},
{"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")}, {"{{ARG_INDEX_TABLE}}", str_util::Join(arg_index_table, ", ")},
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size}, {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
{"{{CLASS}}", opts.class_name}, {"{{CLASS}}", opts.class_name},
{"{{DECLS_FROM_OBJ_FILE}}", {"{{DECLS_FROM_OBJ_FILE}}",
@ -546,8 +592,9 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{RESULT_NAMES_CODE}}", result_names_code}, {"{{RESULT_NAMES_CODE}}", result_names_code},
{"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)}, {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)},
{"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
{"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())}, {"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())},
{"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}}; {"{{BUFFER_INFOS_AS_STRING}}",
str_util::Join(buffer_infos_as_strings, ",\n")}};
str_util::ReplaceAllPairs(header, rewrites); str_util::ReplaceAllPairs(header, rewrites);
return Status::OK(); return Status::OK();
} }

View File

@ -32,6 +32,8 @@ namespace tensorflow {
namespace tfcompile { namespace tfcompile {
namespace { namespace {
using ::tensorflow::cpu_function_runtime::BufferInfo;
void ExpectErrorContains(const Status& status, StringPiece str) { void ExpectErrorContains(const Status& status, StringPiece str) {
EXPECT_NE(Status::OK(), status); EXPECT_NE(Status::OK(), status);
EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) EXPECT_TRUE(str_util::StrContains(status.error_message(), str))
@ -171,8 +173,14 @@ TEST(CodegenTest, Golden) {
fetch->mutable_id()->set_node_name("fetch0"); fetch->mutable_id()->set_node_name("fetch0");
fetch->set_name("myfetch"); fetch->set_name("myfetch");
CompileResult compile_result; CompileResult compile_result;
compile_result.aot.reset( compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5, {})); {},
{BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0),
BufferInfo::MakeTempBuffer(2),
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
5, {}));
compile_result.program_shape = xla::ShapeUtil::MakeProgramShape( compile_result.program_shape = xla::ShapeUtil::MakeProgramShape(
{ {
xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),

View File

@ -65,9 +65,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
static constexpr size_t kNumArgs = 2; static constexpr size_t kNumArgs = 2;
// Byte size of each argument buffer. There are kNumArgs entries. // Byte size of each argument buffer. There are kNumArgs entries.
static const intptr_t* ArgSizes() { static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96}; return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
return kArgSizes;
} }
// Returns static data used to create an XlaCompiledCpuFunction. // Returns static data used to create an XlaCompiledCpuFunction.
@ -75,16 +74,16 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ static XlaCompiledCpuFunction::StaticData* kStaticData = [](){
XlaCompiledCpuFunction::StaticData* data = XlaCompiledCpuFunction::StaticData* data =
new XlaCompiledCpuFunction::StaticData; new XlaCompiledCpuFunction::StaticData;
data->raw_function = entry_point; data->set_raw_function(entry_point);
data->arg_sizes = ArgSizes(); data->set_buffer_infos(BufferInfos());
data->num_args = kNumArgs; data->set_num_buffers(kNumBuffers);
data->temp_sizes = TempSizes(); data->set_arg_index_table(ArgIndexToBufferIndex());
data->num_temps = kNumTemps; data->set_num_args(kNumArgs);
data->result_index = kResultIndex; data->set_result_index(kResultIndex);
data->arg_names = StaticArgNames(); data->set_arg_names(StaticArgNames());
data->result_names = StaticResultNames(); data->set_result_names(StaticResultNames());
data->program_shape = StaticProgramShape(); data->set_program_shape(StaticProgramShape());
data->hlo_profile_printer_data = StaticHloProfilePrinterData(); data->set_hlo_profile_printer_data(StaticHloProfilePrinterData());
return data; return data;
}(); }();
@ -215,17 +214,32 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
} }
private: private:
// Number of result and temporary buffers for the compiled computation. // Number of buffers for the compiled computation.
static constexpr size_t kNumTemps = 6; static constexpr size_t kNumBuffers = 6;
static const ::tensorflow::cpu_function_runtime::BufferInfo* BufferInfos() {
static const ::tensorflow::cpu_function_runtime::BufferInfo
kBufferInfos[kNumBuffers] = {
::tensorflow::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::tensorflow::cpu_function_runtime::BufferInfo({34ULL, 0ULL}),
::tensorflow::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}),
::tensorflow::cpu_function_runtime::BufferInfo({386ULL, 1ULL}),
::tensorflow::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}),
::tensorflow::cpu_function_runtime::BufferInfo({481ULL, ~0ULL})
};
return kBufferInfos;
}
static const ::tensorflow::int32* ArgIndexToBufferIndex() {
static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
1, 3
};
return kArgIndexToBufferIndex;
}
// The 0-based index of the result tuple in the temporary buffers. // The 0-based index of the result tuple in the temporary buffers.
static constexpr size_t kResultIndex = 5; static constexpr size_t kResultIndex = 5;
// Byte size of each result / temporary buffer. There are kNumTemps entries.
static const intptr_t* TempSizes() {
static constexpr intptr_t kTempSizes[kNumTemps] = {1, -1, 2, -1, 3, 120};
return kTempSizes;
}
// Array of names of each positional argument, terminated by nullptr. // Array of names of each positional argument, terminated by nullptr.
static const char** StaticArgNames() { static const char** StaticArgNames() {
static const char* kNames[] = {"myfeed", nullptr}; static const char* kNames[] = {"myfeed", nullptr};

View File

@ -51,11 +51,9 @@ namespace tensorflow {
namespace tfcompile { namespace tfcompile {
namespace { namespace {
void zero_buffers(void** bufs, const intptr_t* sizes, size_t n) { void zero_buffers(XlaCompiledCpuFunction* computation) {
for (int i = 0; i < n; ++i) { for (int i = 0; i < computation->num_args(); ++i) {
if (sizes[i] != -1) { memset(computation->arg_data(i), 0, computation->arg_size(i));
memset(bufs[i], 0, sizes[i]);
}
} }
} }
@ -66,7 +64,7 @@ TEST(TEST_NAME, NoCrash) {
CPP_CLASS computation; CPP_CLASS computation;
computation.set_thread_pool(&device); computation.set_thread_pool(&device);
zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs); zero_buffers(&computation);
EXPECT_TRUE(computation.Run()); EXPECT_TRUE(computation.Run());
} }
@ -80,7 +78,7 @@ void BM_NAME(int iters) {
CPP_CLASS computation; CPP_CLASS computation;
computation.set_thread_pool(&device); computation.set_thread_pool(&device);
zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs); zero_buffers(&computation);
testing::StartTiming(); testing::StartTiming();
while (--iters) { while (--iters) {

View File

@ -44,8 +44,8 @@ using ::testing::IsSupersetOf;
TEST(TFCompileTest, Add) { TEST(TFCompileTest, Add) {
AddComp add; AddComp add;
EXPECT_EQ(add.arg0_data(), add.args()[0]); EXPECT_EQ(add.arg0_data(), add.arg_data(0));
EXPECT_EQ(add.arg1_data(), add.args()[1]); EXPECT_EQ(add.arg1_data(), add.arg_data(1));
add.arg0() = 1; add.arg0() = 1;
add.arg1() = 2; add.arg1() = 2;
@ -67,10 +67,10 @@ TEST(TFCompileTest, Add) {
EXPECT_EQ(add_const.error_msg(), ""); EXPECT_EQ(add_const.error_msg(), "");
EXPECT_EQ(add_const.arg0(), 123); EXPECT_EQ(add_const.arg0(), 123);
EXPECT_EQ(add_const.arg0_data()[0], 123); EXPECT_EQ(add_const.arg0_data()[0], 123);
EXPECT_EQ(add_const.arg0_data(), add.args()[0]); EXPECT_EQ(add_const.arg0_data(), add.arg_data(0));
EXPECT_EQ(add_const.arg1(), 456); EXPECT_EQ(add_const.arg1(), 456);
EXPECT_EQ(add_const.arg1_data()[0], 456); EXPECT_EQ(add_const.arg1_data()[0], 456);
EXPECT_EQ(add_const.arg1_data(), add.args()[1]); EXPECT_EQ(add_const.arg1_data(), add.arg_data(1));
EXPECT_EQ(add_const.result0(), 579); EXPECT_EQ(add_const.result0(), 579);
EXPECT_EQ(add_const.result0_data()[0], 579); EXPECT_EQ(add_const.result0_data()[0], 579);
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
@ -85,8 +85,8 @@ TEST(TFCompileTest, Add_SetArg) {
int32 arg_y = 32; int32 arg_y = 32;
add.set_arg0_data(&arg_x); add.set_arg0_data(&arg_x);
add.set_arg1_data(&arg_y); add.set_arg1_data(&arg_y);
EXPECT_EQ(add.arg0_data(), add.args()[0]); EXPECT_EQ(add.arg0_data(), add.arg_data(0));
EXPECT_EQ(add.arg1_data(), add.args()[1]); EXPECT_EQ(add.arg1_data(), add.arg_data(1));
EXPECT_TRUE(add.Run()); EXPECT_TRUE(add.Run());
EXPECT_EQ(add.error_msg(), ""); EXPECT_EQ(add.error_msg(), "");
@ -97,7 +97,7 @@ TEST(TFCompileTest, Add_SetArg) {
TEST(TFCompileTest, AddWithCkpt) { TEST(TFCompileTest, AddWithCkpt) {
AddWithCkptComp add; AddWithCkptComp add;
EXPECT_EQ(add.arg0_data(), add.args()[0]); EXPECT_EQ(add.arg0_data(), add.arg_data(0));
add.arg0() = 1; add.arg0() = 1;
EXPECT_TRUE(add.Run()); EXPECT_TRUE(add.Run());
@ -117,7 +117,7 @@ TEST(TFCompileTest, AddWithCkpt) {
EXPECT_EQ(add_const.error_msg(), ""); EXPECT_EQ(add_const.error_msg(), "");
EXPECT_EQ(add_const.arg0(), 111); EXPECT_EQ(add_const.arg0(), 111);
EXPECT_EQ(add_const.arg0_data()[0], 111); EXPECT_EQ(add_const.arg0_data()[0], 111);
EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]); EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0));
EXPECT_EQ(add_const.result0(), 153); EXPECT_EQ(add_const.result0(), 153);
EXPECT_EQ(add_const.result0_data()[0], 153); EXPECT_EQ(add_const.result0_data()[0], 153);
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
@ -125,7 +125,7 @@ TEST(TFCompileTest, AddWithCkpt) {
TEST(TFCompileTest, AddWithCkptSaver) { TEST(TFCompileTest, AddWithCkptSaver) {
AddWithCkptSaverComp add; AddWithCkptSaverComp add;
EXPECT_EQ(add.arg0_data(), add.args()[0]); EXPECT_EQ(add.arg0_data(), add.arg_data(0));
add.arg0() = 1; add.arg0() = 1;
EXPECT_TRUE(add.Run()); EXPECT_TRUE(add.Run());
@ -145,7 +145,7 @@ TEST(TFCompileTest, AddWithCkptSaver) {
EXPECT_EQ(add_const.error_msg(), ""); EXPECT_EQ(add_const.error_msg(), "");
EXPECT_EQ(add_const.arg0(), 111); EXPECT_EQ(add_const.arg0(), 111);
EXPECT_EQ(add_const.arg0_data()[0], 111); EXPECT_EQ(add_const.arg0_data()[0], 111);
EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]); EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0));
EXPECT_EQ(add_const.result0(), 153); EXPECT_EQ(add_const.result0(), 153);
EXPECT_EQ(add_const.result0_data()[0], 153); EXPECT_EQ(add_const.result0_data()[0], 153);
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
@ -153,9 +153,9 @@ TEST(TFCompileTest, AddWithCkptSaver) {
TEST(TFCompileTest, Cond) { TEST(TFCompileTest, Cond) {
CondComp cond; CondComp cond;
EXPECT_EQ(cond.arg0_data(), cond.args()[0]); EXPECT_EQ(cond.arg0_data(), cond.arg_data(0));
EXPECT_EQ(cond.arg1_data(), cond.args()[1]); EXPECT_EQ(cond.arg1_data(), cond.arg_data(1));
EXPECT_EQ(cond.arg2_data(), cond.args()[2]); EXPECT_EQ(cond.arg2_data(), cond.arg_data(2));
cond.arg1() = 10; cond.arg1() = 10;
cond.arg2() = 20; cond.arg2() = 20;
{ {
@ -178,8 +178,8 @@ TEST(TFCompileTest, Cond) {
TEST(TFCompileTest, Gather) { TEST(TFCompileTest, Gather) {
GatherComp gather; GatherComp gather;
EXPECT_EQ(gather.arg0_data(), gather.args()[0]); EXPECT_EQ(gather.arg0_data(), gather.arg_data(0));
EXPECT_EQ(gather.arg1_data(), gather.args()[1]); EXPECT_EQ(gather.arg1_data(), gather.arg_data(1));
// Successful gather. // Successful gather.
{ {
@ -202,12 +202,12 @@ TEST(TFCompileTest, Gather) {
EXPECT_EQ(gather_const.arg0(i), params[i]); EXPECT_EQ(gather_const.arg0(i), params[i]);
EXPECT_EQ(gather_const.arg0_data()[i], params[i]); EXPECT_EQ(gather_const.arg0_data()[i], params[i]);
} }
EXPECT_EQ(gather_const.arg0_data(), gather_const.args()[0]); EXPECT_EQ(gather_const.arg0_data(), gather_const.arg_data(0));
for (int i = 0; i < 2; ++i) { for (int i = 0; i < 2; ++i) {
EXPECT_EQ(gather_const.arg1(i), indices[i]); EXPECT_EQ(gather_const.arg1(i), indices[i]);
EXPECT_EQ(gather_const.arg1_data()[i], indices[i]); EXPECT_EQ(gather_const.arg1_data()[i], indices[i]);
} }
EXPECT_EQ(gather_const.arg1_data(), gather_const.args()[1]); EXPECT_EQ(gather_const.arg1_data(), gather_const.arg_data(1));
for (int i = 0; i < 2; ++i) { for (int i = 0; i < 2; ++i) {
EXPECT_EQ(gather_const.result0(i), results[i]); EXPECT_EQ(gather_const.result0(i), results[i]);
EXPECT_EQ(gather_const.result0_data()[i], results[i]); EXPECT_EQ(gather_const.result0_data()[i], results[i]);
@ -222,8 +222,8 @@ TEST(TFCompileTest, MatMul2) {
foo::bar::MatMulComp matmul; foo::bar::MatMulComp matmul;
matmul.set_thread_pool(&device); matmul.set_thread_pool(&device);
EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]); EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0));
EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]); EXPECT_EQ(matmul.arg1_data(), matmul.arg_data(1));
// Test using the argN() methods. // Test using the argN() methods.
{ {
@ -271,12 +271,12 @@ TEST(TFCompileTest, MatMul2) {
EXPECT_EQ(matmul_const.arg0(i / 3, i % 3), args[i]); EXPECT_EQ(matmul_const.arg0(i / 3, i % 3), args[i]);
EXPECT_EQ(matmul_const.arg0_data()[i], args[i]); EXPECT_EQ(matmul_const.arg0_data()[i], args[i]);
} }
EXPECT_EQ(matmul_const.arg0_data(), matmul.args()[0]); EXPECT_EQ(matmul_const.arg0_data(), matmul.arg_data(0));
for (int i = 0; i < 6; ++i) { for (int i = 0; i < 6; ++i) {
EXPECT_EQ(matmul_const.arg1(i / 2, i % 2), args[i + 6]); EXPECT_EQ(matmul_const.arg1(i / 2, i % 2), args[i + 6]);
EXPECT_EQ(matmul_const.arg1_data()[i], args[i + 6]); EXPECT_EQ(matmul_const.arg1_data()[i], args[i + 6]);
} }
EXPECT_EQ(matmul_const.arg1_data(), matmul.args()[1]); EXPECT_EQ(matmul_const.arg1_data(), matmul.arg_data(1));
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
EXPECT_EQ(matmul_const.result0(i / 2, i % 2), results[i]); EXPECT_EQ(matmul_const.result0(i / 2, i % 2), results[i]);
EXPECT_EQ(matmul_const.result0_data()[i], results[i]); EXPECT_EQ(matmul_const.result0_data()[i], results[i]);
@ -300,8 +300,8 @@ TEST(TFCompileTest, MatMul2_SetArg) {
float arg1[3][2] = {{7, 8}, {9, 10}, {11, 12}}; float arg1[3][2] = {{7, 8}, {9, 10}, {11, 12}};
matmul.set_arg0_data(&arg0); matmul.set_arg0_data(&arg0);
matmul.set_arg1_data(&arg1); matmul.set_arg1_data(&arg1);
EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]); EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0));
EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]); EXPECT_EQ(matmul.arg1_data(), matmul.arg_data(1));
EXPECT_TRUE(matmul.Run()); EXPECT_TRUE(matmul.Run());
EXPECT_EQ(matmul.error_msg(), ""); EXPECT_EQ(matmul.error_msg(), "");
@ -319,8 +319,8 @@ TEST(TFCompileTest, MatMulAndAdd1) {
MatMulAndAddComp muladd; MatMulAndAddComp muladd;
muladd.set_thread_pool(&device); muladd.set_thread_pool(&device);
EXPECT_EQ(muladd.arg0_data(), muladd.args()[0]); EXPECT_EQ(muladd.arg0_data(), muladd.arg_data(0));
EXPECT_EQ(muladd.arg1_data(), muladd.args()[1]); EXPECT_EQ(muladd.arg1_data(), muladd.arg_data(1));
// Test methods with positional args and results. // Test methods with positional args and results.
{ {
@ -346,12 +346,12 @@ TEST(TFCompileTest, MatMulAndAdd1) {
EXPECT_EQ(muladd_const.arg0(i / 2, i % 2), args[i]); EXPECT_EQ(muladd_const.arg0(i / 2, i % 2), args[i]);
EXPECT_EQ(muladd_const.arg0_data()[i], args[i]); EXPECT_EQ(muladd_const.arg0_data()[i], args[i]);
} }
EXPECT_EQ(muladd_const.arg0_data(), muladd.args()[0]); EXPECT_EQ(muladd_const.arg0_data(), muladd.arg_data(0));
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.arg1(i / 2, i % 2), args[i + 4]); EXPECT_EQ(muladd_const.arg1(i / 2, i % 2), args[i + 4]);
EXPECT_EQ(muladd_const.arg1_data()[i], args[i + 4]); EXPECT_EQ(muladd_const.arg1_data()[i], args[i + 4]);
} }
EXPECT_EQ(muladd_const.arg1_data(), muladd.args()[1]); EXPECT_EQ(muladd_const.arg1_data(), muladd.arg_data(1));
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.result0(i / 2, i % 2), results0[i]); EXPECT_EQ(muladd_const.result0(i / 2, i % 2), results0[i]);
EXPECT_EQ(muladd_const.result0_data()[i], results0[i]); EXPECT_EQ(muladd_const.result0_data()[i], results0[i]);
@ -387,12 +387,12 @@ TEST(TFCompileTest, MatMulAndAdd1) {
EXPECT_EQ(muladd_const.arg_x(i / 2, i % 2), args[i]); EXPECT_EQ(muladd_const.arg_x(i / 2, i % 2), args[i]);
EXPECT_EQ(muladd_const.arg_x_data()[i], args[i]); EXPECT_EQ(muladd_const.arg_x_data()[i], args[i]);
} }
EXPECT_EQ(muladd_const.arg_x_data(), muladd.args()[0]); EXPECT_EQ(muladd_const.arg_x_data(), muladd.arg_data(0));
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.arg_y(i / 2, i % 2), args[i + 4]); EXPECT_EQ(muladd_const.arg_y(i / 2, i % 2), args[i + 4]);
EXPECT_EQ(muladd_const.arg_y_data()[i], args[i + 4]); EXPECT_EQ(muladd_const.arg_y_data()[i], args[i + 4]);
} }
EXPECT_EQ(muladd_const.arg_y_data(), muladd.args()[1]); EXPECT_EQ(muladd_const.arg_y_data(), muladd.arg_data(1));
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.result_x_y_prod(i / 2, i % 2), results0[i]); EXPECT_EQ(muladd_const.result_x_y_prod(i / 2, i % 2), results0[i]);
EXPECT_EQ(muladd_const.result_x_y_prod_data()[i], results0[i]); EXPECT_EQ(muladd_const.result_x_y_prod_data()[i], results0[i]);
@ -407,8 +407,8 @@ TEST(TFCompileTest, MatMulAndAdd1) {
TEST(TFCompileTest, Function) { TEST(TFCompileTest, Function) {
// The function is equivalent to an addition // The function is equivalent to an addition
FunctionComp add_fn; FunctionComp add_fn;
EXPECT_EQ(add_fn.arg0_data(), add_fn.args()[0]); EXPECT_EQ(add_fn.arg0_data(), add_fn.arg_data(0));
EXPECT_EQ(add_fn.arg1_data(), add_fn.args()[1]); EXPECT_EQ(add_fn.arg1_data(), add_fn.arg_data(1));
add_fn.arg0() = 1; add_fn.arg0() = 1;
add_fn.arg1() = 2; add_fn.arg1() = 2;
@ -451,8 +451,8 @@ TEST(TFCompileTest, AssertEqAndReturnDiff) {
// Assert is converted into a no-op in XLA, so there is no failure even if the // Assert is converted into a no-op in XLA, so there is no failure even if the
// two args are different. // two args are different.
AssertComp assert; AssertComp assert;
EXPECT_EQ(assert.arg0_data(), assert.args()[0]); EXPECT_EQ(assert.arg0_data(), assert.arg_data(0));
EXPECT_EQ(assert.arg1_data(), assert.args()[1]); EXPECT_EQ(assert.arg1_data(), assert.arg_data(1));
assert.arg0() = 2; assert.arg0() = 2;
assert.arg1() = 1; assert.arg1() = 1;
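The tests above consistently replace the old `args()[i]` pointer array with the index-based `arg_data(i)` accessor and assert that it aliases the same buffer as the typed `argN_data()` getters. A minimal standalone mock of that aliasing contract is sketched below; `FakeAddComp` is a hypothetical stand-in, not the tfcompile-generated class these tests actually include.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a tfcompile-generated class with two int32 args.
// The typed arg0_data()/arg1_data() getters and the index-based arg_data(i)
// accessor are expected to return the same underlying buffers.
class FakeAddComp {
 public:
  int32_t& arg0() { return arg0_; }
  int32_t& arg1() { return arg1_; }
  void* arg0_data() { return &arg0_; }
  void* arg1_data() { return &arg1_; }
  void* arg_data(int index) { return index == 0 ? arg0_data() : arg1_data(); }

 private:
  int32_t arg0_ = 0;
  int32_t arg1_ = 0;
};

int main() {
  FakeAddComp add;
  // Mirrors the EXPECT_EQ checks above: both access paths see the same memory.
  assert(add.arg0_data() == add.arg_data(0));
  assert(add.arg1_data() == add.arg_data(1));
  add.arg0() = 1;
  add.arg1() = 2;
  return 0;
}
```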

View File

@ -160,6 +160,7 @@ cc_library(
"//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:util",
@ -178,6 +179,7 @@ cc_library(
"//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:fifo_queue", "//tensorflow/core/kernels:fifo_queue",
"//tensorflow/core/kernels:function_ops",
"//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_n_op",
"//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:no_op", "//tensorflow/core/kernels:no_op",
@ -186,6 +188,9 @@ cc_library(
"//tensorflow/core/kernels:sendrecv_ops", "//tensorflow/core/kernels:sendrecv_ops",
"//tensorflow/core/kernels:shape_ops", "//tensorflow/core/kernels:shape_ops",
"//tensorflow/core/kernels:variable_ops", "//tensorflow/core/kernels:variable_ops",
"//tensorflow/core/kernels/data:generator_dataset_op",
"//tensorflow/core/kernels/data:iterator_ops",
"//tensorflow/core/kernels/data:prefetch_dataset_op",
], ],
) )

View File

@ -46,6 +46,7 @@ class Predicate {
virtual string ToString() const = 0; virtual string ToString() const = 0;
int64 hash() const { return hash_; } int64 hash() const { return hash_; }
virtual gtl::ArraySlice<Predicate*> GetOperands() const = 0;
virtual Kind kind() const = 0; virtual Kind kind() const = 0;
virtual ~Predicate() {} virtual ~Predicate() {}
@ -90,7 +91,8 @@ class AndPredicate : public Predicate {
Kind kind() const override { return Kind::kAnd; } Kind kind() const override { return Kind::kAnd; }
const gtl::ArraySlice<Predicate*> operands() const { return operands_; } gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
gtl::ArraySlice<Predicate*> operands() const { return operands_; }
private: private:
std::vector<Predicate*> operands_; std::vector<Predicate*> operands_;
@ -117,7 +119,8 @@ class OrPredicate : public Predicate {
} }
Kind kind() const override { return Kind::kOr; } Kind kind() const override { return Kind::kOr; }
const gtl::ArraySlice<Predicate*> operands() const { return operands_; } gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
gtl::ArraySlice<Predicate*> operands() const { return operands_; }
private: private:
std::vector<Predicate*> operands_; std::vector<Predicate*> operands_;
@ -128,17 +131,18 @@ class NotPredicate : public Predicate {
public: public:
explicit NotPredicate(Predicate* operand) explicit NotPredicate(Predicate* operand)
: Predicate(HashPredicateSequence(Kind::kNot, {operand})), : Predicate(HashPredicateSequence(Kind::kNot, {operand})),
operand_(operand) {} operands_({operand}) {}
string ToString() const override { string ToString() const override {
return strings::StrCat("~", operand()->ToString()); return strings::StrCat("~", operand()->ToString());
} }
Kind kind() const override { return Kind::kNot; } Kind kind() const override { return Kind::kNot; }
Predicate* operand() const { return operand_; } Predicate* operand() const { return operands_[0]; }
gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
private: private:
Predicate* operand_; std::array<Predicate*, 1> operands_;
}; };
// Represents an uninterpreted symbol in a logical predicate. // Represents an uninterpreted symbol in a logical predicate.
@ -158,6 +162,7 @@ class SymbolPredicate : public Predicate {
} }
Kind kind() const override { return Kind::kSymbol; } Kind kind() const override { return Kind::kSymbol; }
gtl::ArraySlice<Predicate*> GetOperands() const override { return {}; }
// If `must_be_true()` is true this SymbolPredicate represents the proposition // If `must_be_true()` is true this SymbolPredicate represents the proposition
// "tensor_id() is live and evaluates to true". // "tensor_id() is live and evaluates to true".
@ -288,10 +293,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands,
if (op->kind() == pred_kind) { if (op->kind() == pred_kind) {
// "Inline" the operands of an inner And/Or into the parent And/Or. // "Inline" the operands of an inner And/Or into the parent And/Or.
gtl::ArraySlice<Predicate*> operands = for (Predicate* subop : op->GetOperands()) {
is_and ? dynamic_cast<AndPredicate*>(op)->operands()
: dynamic_cast<OrPredicate*>(op)->operands();
for (Predicate* subop : operands) {
if (simplified_ops_set.insert(subop).second) { if (simplified_ops_set.insert(subop).second) {
simplified_ops.push_back(subop); simplified_ops.push_back(subop);
} }
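The hunk above removes the per-subclass `dynamic_cast` in `MakeAndOrImpl` by giving the `Predicate` base class a virtual `GetOperands()`. A self-contained sketch of that refactoring pattern follows; the simplified `Pred`/`AndPred`/`LeafPred` types are illustrative only, not the TensorFlow classes.

```cpp
#include <cstdio>
#include <utility>
#include <vector>

// The base class exposes its operands directly, so callers that flatten
// nested And/Or nodes no longer need to dynamic_cast to a concrete subclass.
class Pred {
 public:
  virtual ~Pred() = default;
  virtual std::vector<Pred*> GetOperands() const = 0;
};

class AndPred : public Pred {
 public:
  explicit AndPred(std::vector<Pred*> operands) : operands_(std::move(operands)) {}
  std::vector<Pred*> GetOperands() const override { return operands_; }

 private:
  std::vector<Pred*> operands_;
};

class LeafPred : public Pred {
 public:
  std::vector<Pred*> GetOperands() const override { return {}; }
};

int main() {
  LeafPred a, b;
  AndPred inner({&a, &b});
  AndPred outer({&inner});
  // "Inline" the operands of the inner And into the parent, as MakeAndOrImpl
  // does above, without knowing the concrete operand type.
  std::vector<Pred*> flattened;
  for (Pred* op : outer.GetOperands()) {
    for (Pred* subop : op->GetOperands()) {
      flattened.push_back(subop);
    }
  }
  std::printf("flattened %zu operands\n", flattened.size());  // prints 2
  return 0;
}
```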

View File

@ -1161,8 +1161,7 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef(
strings::StrCat("replace_encapsulate_fdef_", name), fdef); strings::StrCat("replace_encapsulate_fdef_", name), fdef);
} }
TF_RETURN_IF_ERROR(library->RemoveFunction(name)); TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
return Status::OK(); return Status::OK();
} }

View File

@ -16,6 +16,7 @@ cc_library(
"//tensorflow/compiler/jit:xla_device", "//tensorflow/compiler/jit:xla_device",
"//tensorflow/compiler/jit:xla_launch_util", "//tensorflow/compiler/jit:xla_launch_util",
"//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:client_library",

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/client_library.h"
@ -199,7 +200,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
run_options.set_stream(stream); run_options.set_stream(stream);
run_options.set_allocator(xla_allocator); run_options.set_allocator(xla_allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(ctx->step_id()); run_options.set_rng_seed(GetXLARandomSeed());
Env* env = Env::Default(); Env* env = Env::Default();
auto start_time = env->NowMicros(); auto start_time = env->NowMicros();
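Both launch paths in this commit now seed the XLA run options with `GetXLARandomSeed()` instead of the per-step `ctx->step_id()`, so the seed is no longer tied to (and repeated across) step ids. The snippet below is only a guess at what such a process-wide seed helper could look like (an atomic counter started from a random value); it is not the actual TensorFlow implementation.

```cpp
#include <atomic>
#include <cstdint>
#include <random>

// Hypothetical sketch: return a distinct, hard-to-predict 64-bit seed on each
// call. Assumes nothing about the real GetXLARandomSeed() beyond its purpose.
uint64_t ExampleXlaRandomSeed() {
  static std::atomic<uint64_t> counter{std::random_device{}()};
  // Step by a large odd constant so successive seeds differ in many bits.
  return counter.fetch_add(0x9E3779B97F4A7C15ull, std::memory_order_relaxed);
}
```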

View File

@ -296,7 +296,7 @@ Status XlaCompilationCache::CompileImpl(
// protect the contents of the cache entry. // protect the contents of the cache entry.
Entry* entry; Entry* entry;
{ {
mutex_lock lock(mu_); mutex_lock lock(compile_cache_mu_);
// Find or create a cache entry. // Find or create a cache entry.
std::unique_ptr<Entry>& e = cache_[signature]; std::unique_ptr<Entry>& e = cache_[signature];
if (!e) { if (!e) {
@ -312,6 +312,8 @@ Status XlaCompilationCache::CompileImpl(
if (!entry->compiled) { if (!entry->compiled) {
VLOG(1) << "Compilation cache miss for signature: " VLOG(1) << "Compilation cache miss for signature: "
<< SignatureDebugString(signature); << SignatureDebugString(signature);
tensorflow::Env* env = tensorflow::Env::Default();
const uint64 compile_start_us = env->NowMicros();
// Do the actual JIT compilation without holding the lock (it can take // Do the actual JIT compilation without holding the lock (it can take
// a long time.) // a long time.)
std::vector<XlaCompiler::Argument> args; std::vector<XlaCompiler::Argument> args;
@ -334,6 +336,26 @@ Status XlaCompilationCache::CompileImpl(
CHECK_EQ(entry->executable.get(), nullptr); CHECK_EQ(entry->executable.get(), nullptr);
entry->compilation_status = entry->compilation_status =
BuildExecutable(options, entry->compilation_result, &entry->executable); BuildExecutable(options, entry->compilation_result, &entry->executable);
const uint64 compile_end_us = env->NowMicros();
const uint64 compile_time_us = compile_end_us - compile_start_us;
{
mutex_lock lock(compile_stats_mu_);
auto it = compile_stats_.emplace(function.name(), CompileStats{}).first;
it->second.compile_count++;
it->second.cumulative_compile_time_us += compile_time_us;
VLOG(1) << "compiled " << function.name() << " "
<< it->second.compile_count
<< " times, compile time: " << compile_time_us
<< " us, cumulative: " << it->second.cumulative_compile_time_us
<< " us ("
<< tensorflow::strings::HumanReadableElapsedTime(compile_time_us /
1.0e6)
<< " / "
<< tensorflow::strings::HumanReadableElapsedTime(
it->second.cumulative_compile_time_us / 1.0e6)
<< ")";
}
} }
TF_RETURN_IF_ERROR(entry->compilation_status); TF_RETURN_IF_ERROR(entry->compilation_status);
*compilation_result = &entry->compilation_result; *compilation_result = &entry->compilation_result;

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/thread_annotations.h"
@ -150,9 +151,22 @@ class XlaCompilationCache : public ResourceBase {
std::unique_ptr<xla::LocalExecutable> executable GUARDED_BY(mu); std::unique_ptr<xla::LocalExecutable> executable GUARDED_BY(mu);
}; };
mutex mu_; mutex compile_cache_mu_;
std::unordered_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_ gtl::FlatMap<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
GUARDED_BY(mu_); GUARDED_BY(compile_cache_mu_);
struct CompileStats {
// Number of times the cluster has been (re-)compiled.
int64 compile_count = 0;
// Cumulative time spent compiling the cluster.
int64 cumulative_compile_time_us = 0;
};
mutex compile_stats_mu_;
// Maps cluster names to compilation statistics for said cluster.
gtl::FlatMap<string, CompileStats> compile_stats_
GUARDED_BY(compile_stats_mu_);
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache); TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
}; };
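Together, the cache and header changes above split the locking (`compile_cache_mu_` for the cache, a separate `compile_stats_mu_` for bookkeeping) and record, per cluster name, how many times it was compiled and the cumulative compile time. A self-contained sketch of the same bookkeeping pattern, using standard-library containers in place of `gtl::FlatMap` and TensorFlow's mutex:

```cpp
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <mutex>
#include <string>
#include <unordered_map>

struct CompileStats {
  int64_t compile_count = 0;               // times the cluster was (re-)compiled
  int64_t cumulative_compile_time_us = 0;  // total time spent compiling it
};

std::mutex compile_stats_mu;
std::unordered_map<std::string, CompileStats> compile_stats;

// Records one compilation of `cluster` that took `compile_time_us` microseconds.
void RecordCompilation(const std::string& cluster, int64_t compile_time_us) {
  std::lock_guard<std::mutex> lock(compile_stats_mu);
  CompileStats& stats = compile_stats[cluster];
  stats.compile_count++;
  stats.cumulative_compile_time_us += compile_time_us;
  std::printf("compiled %s %lld times, compile time: %lld us, cumulative: %lld us\n",
              cluster.c_str(), static_cast<long long>(stats.compile_count),
              static_cast<long long>(compile_time_us),
              static_cast<long long>(stats.cumulative_compile_time_us));
}

int main() {
  const auto start = std::chrono::steady_clock::now();
  // ... JIT-compile the cluster here ...
  const auto end = std::chrono::steady_clock::now();
  RecordCompilation("example_cluster",
                    std::chrono::duration_cast<std::chrono::microseconds>(end - start)
                        .count());
  return 0;
}
```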

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@ -71,7 +72,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
run_options.set_stream(stream); run_options.set_stream(stream);
run_options.set_allocator(client->backend().memory_allocator()); run_options.set_allocator(client->backend().memory_allocator());
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(ctx->step_id()); run_options.set_rng_seed(GetXLARandomSeed());
xla::StatusOr<xla::ScopedShapedBuffer> run_result = xla::StatusOr<xla::ScopedShapedBuffer> run_result =
executable->Run(launch_context.arguments(), run_options); executable->Run(launch_context.arguments(), run_options);

View File

@ -211,17 +211,18 @@ XlaDevice::XlaDevice(
use_multiple_streams), use_multiple_streams),
device_ordinal_(device_ordinal), device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name), jit_device_name_(jit_device_name),
xla_allocator_(nullptr),
platform_(platform), platform_(platform),
use_multiple_streams_(use_multiple_streams), use_multiple_streams_(use_multiple_streams),
transfer_as_literal_(transfer_as_literal), transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(shape_representation_fn) { shape_representation_fn_(shape_representation_fn) {
VLOG(1) << "Created XLA device " << jit_device_name; VLOG(1) << "Created XLA device " << jit_device_name << " " << this;
} }
XlaDevice::~XlaDevice() { XlaDevice::~XlaDevice() {
if (gpu_device_info_ != nullptr) { VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this;
gpu_device_info_->default_context->Unref(); mutex_lock lock(mu_);
if (device_context_) {
device_context_->Unref();
} }
} }
@ -237,6 +238,11 @@ xla::LocalClient* XlaDevice::client() const {
} }
Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) { Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
mutex_lock lock(mu_);
return GetAllocatorLocked(attr);
}
Allocator* XlaDevice::GetAllocatorLocked(AllocatorAttributes attr) {
if (attr.on_host()) { if (attr.on_host()) {
return cpu_allocator(); return cpu_allocator();
} }
@ -249,83 +255,105 @@ Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
return xla_allocator_; return xla_allocator_;
} }
xla::StatusOr<se::Stream*> XlaDevice::GetStream() { Status XlaDevice::EnsureDeviceContextOk() {
if (!stream_) { mutex_lock lock(mu_);
xla::Backend* backend = client()->mutable_backend(); return GetDeviceContextLocked().status();
TF_ASSIGN_OR_RETURN(stream_, backend->BorrowStream(device_ordinal_));
}
return stream_.get();
} }
xla::StatusOr<se::Stream*> XlaDevice::GetDeviceToHostStream() { Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend,
if (!use_multiple_streams_) { const string& name,
return GetStream(); xla::StreamPool::Ptr* stream,
bool* stream_was_changed) {
if (!(*stream) || !(*stream)->ok()) {
TF_ASSIGN_OR_RETURN(*stream, backend->BorrowStream(device_ordinal_));
VLOG(1) << "XlaDevice " << this << " new " << name << " "
<< (*stream)->DebugStreamPointers();
*stream_was_changed = true;
} }
if (!device_to_host_stream_) {
xla::Backend* backend = client()->mutable_backend();
TF_ASSIGN_OR_RETURN(device_to_host_stream_,
backend->BorrowStream(device_ordinal_));
}
return device_to_host_stream_.get();
}
xla::StatusOr<se::Stream*> XlaDevice::GetHostToDeviceStream() {
if (!use_multiple_streams_) {
return GetStream();
}
if (!host_to_device_stream_) {
xla::Backend* backend = client()->mutable_backend();
TF_ASSIGN_OR_RETURN(host_to_device_stream_,
backend->BorrowStream(device_ordinal_));
}
return host_to_device_stream_.get();
}
Status XlaDevice::CreateAndSetGpuDeviceInfo() {
if (gpu_device_info_ == nullptr) {
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
// Call GetAllocator for the side-effect of ensuring the allocator
// is created.
GetAllocator({});
// XlaDevice owns both gpu_device_info_ and
// gpu_device_info_->default_context.
gpu_device_info_ = MakeUnique<GpuDeviceInfo>();
gpu_device_info_->stream = stream;
gpu_device_info_->default_context =
new XlaDeviceContext(stream, stream, stream, client(),
transfer_as_literal_, shape_representation_fn_);
set_tensorflow_gpu_device_info(gpu_device_info_.get());
}
return Status::OK(); return Status::OK();
} }
xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
xla::Backend* backend = client()->mutable_backend();
// Ensure all our streams are valid, borrowing new streams if necessary.
bool need_new_device_context = !device_context_;
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
&need_new_device_context));
se::Stream* host_to_device_stream = stream_.get();
se::Stream* device_to_host_stream = stream_.get();
if (use_multiple_streams_) {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream",
&host_to_device_stream_,
&need_new_device_context));
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream",
&device_to_host_stream_,
&need_new_device_context));
host_to_device_stream = host_to_device_stream_.get();
device_to_host_stream = device_to_host_stream_.get();
}
if (!need_new_device_context) {
return device_context_;
}
// At this point we know we need a new device context.
// Call GetAllocator for the side-effect of ensuring the allocator is created.
GetAllocatorLocked({});
if (device_context_) {
device_context_->Unref();
}
device_context_ = new XlaDeviceContext(
stream_.get(), host_to_device_stream, device_to_host_stream, client(),
transfer_as_literal_, shape_representation_fn_);
VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext "
<< device_context_;
// Create and set a new GpuDeviceInfo, if necessary.
//
// TODO(b/78232898): This isn't thread-safe; there is a race between the call
// to set_tensorflow_gpu_device_info() with ops that call the getter
// tensorflow_gpu_device_info(). This isn't trivially fixed by adding locking
// to those methods; see the bug for details. Our only saving grace at the
// moment is that this race doesn't seem to occur in practice.
if (use_gpu_device_info_) {
auto gpu_device_info = MakeUnique<GpuDeviceInfo>();
gpu_device_info->stream = stream_.get();
gpu_device_info->default_context = device_context_;
set_tensorflow_gpu_device_info(gpu_device_info.get());
gpu_device_info_ = std::move(gpu_device_info);
VLOG(1) << "XlaDevice " << this << " new GpuDeviceInfo "
<< gpu_device_info_.get();
}
return device_context_;
}
Status XlaDevice::UseGpuDeviceInfo() {
mutex_lock lock(mu_);
use_gpu_device_info_ = true;
return GetDeviceContextLocked().status();
}
Status XlaDevice::FillContextMap(const Graph* graph, Status XlaDevice::FillContextMap(const Graph* graph,
DeviceContextMap* device_context_map) { DeviceContextMap* device_context_map) {
VLOG(1) << "XlaDevice::FillContextMap"; VLOG(1) << "XlaDevice::FillContextMap";
device_context_map->resize(graph->num_node_ids()); mutex_lock lock(mu_);
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); TF_ASSIGN_OR_RETURN(XlaDeviceContext * device_context,
TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream, GetDeviceContextLocked());
GetDeviceToHostStream());
TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
GetHostToDeviceStream());
// Call GetAllocator for the side-effect of ensuring the allocator is created. device_context_map->resize(graph->num_node_ids());
GetAllocator({});
auto ctx = new XlaDeviceContext(
stream, host_to_device_stream, device_to_host_stream, client(),
transfer_as_literal_, shape_representation_fn_);
for (Node* n : graph->nodes()) { for (Node* n : graph->nodes()) {
VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name(); VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
ctx->Ref(); device_context->Ref();
(*device_context_map)[n->id()] = ctx; (*device_context_map)[n->id()] = device_context;
} }
ctx->Unref();
return Status::OK(); return Status::OK();
} }
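The rewritten device-context path above follows one pattern throughout: under a single `mu_`, check that each borrowed stream is still healthy, borrow a replacement if it is not, and rebuild the `XlaDeviceContext` whenever any stream changed, because the context caches raw stream pointers. A standalone sketch of that pattern with placeholder types (the real `XlaDeviceContext` is ref-counted; `unique_ptr` keeps the sketch short):

```cpp
#include <memory>
#include <mutex>

// Placeholder types standing in for the stream-pool pointer and the device
// context used above.
struct Stream {
  bool ok = true;
};
struct DeviceContext {
  Stream* compute_stream = nullptr;  // cached raw pointer; the context must be
                                     // rebuilt whenever the stream is replaced
};

std::unique_ptr<Stream> BorrowStream() { return std::make_unique<Stream>(); }

class LazyDevice {
 public:
  DeviceContext* GetDeviceContext() {
    std::lock_guard<std::mutex> lock(mu_);
    bool need_new_context = (context_ == nullptr);
    EnsureStreamOkLocked(&stream_, &need_new_context);
    if (need_new_context) {
      context_ = std::make_unique<DeviceContext>();
      context_->compute_stream = stream_.get();
    }
    return context_.get();
  }

 private:
  static void EnsureStreamOkLocked(std::unique_ptr<Stream>* stream,
                                   bool* stream_was_changed) {
    if (!*stream || !(*stream)->ok) {
      *stream = BorrowStream();  // replace a missing or failed stream
      *stream_was_changed = true;
    }
  }

  std::mutex mu_;
  std::unique_ptr<Stream> stream_;
  std::unique_ptr<DeviceContext> context_;
};
```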
void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
VLOG(1) << "XlaDevice::Compute " << op_kernel->name() << ":" VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
<< op_kernel->type_string(); << op_kernel->type_string();
// When Xprof profiling is off (which is the default), constructing the // When Xprof profiling is off (which is the default), constructing the
// activity is simple enough that its overhead is negligible. // activity is simple enough that its overhead is negligible.
@ -336,7 +364,7 @@ void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) { AsyncOpKernel::DoneCallback done) {
VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":" VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
<< op_kernel->type_string(); << op_kernel->type_string();
tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(), tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(),
op_kernel->IsExpensive()); op_kernel->IsExpensive());
@ -358,17 +386,13 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
if (alloc_attrs.on_host()) { if (alloc_attrs.on_host()) {
*tensor = parsed; *tensor = parsed;
} else { } else {
Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); mutex_lock lock(mu_);
TF_ASSIGN_OR_RETURN(XlaDeviceContext * device_context,
GetDeviceContextLocked());
Allocator* allocator = GetAllocatorLocked(alloc_attrs);
Tensor copy(allocator, parsed.dtype(), parsed.shape());
Notification n; Notification n;
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); device_context->CopyCPUTensorToDevice(&parsed, this, &copy,
TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
GetDeviceToHostStream());
TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
GetHostToDeviceStream());
XlaTransferManager manager(stream, host_to_device_stream,
device_to_host_stream, client(),
transfer_as_literal_, shape_representation_fn_);
manager.CopyCPUTensorToDevice(&parsed, this, &copy,
[&n, &status](const Status& s) { [&n, &status](const Status& s) {
status = s; status = s;
n.Notify(); n.Notify();

View File

@ -25,6 +25,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@ -40,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace tensorflow { namespace tensorflow {
@ -117,62 +119,85 @@ class XlaDevice : public LocalDevice {
const PaddedShapeFn& padded_shape_fn); const PaddedShapeFn& padded_shape_fn);
~XlaDevice() override; ~XlaDevice() override;
Allocator* GetAllocator(AllocatorAttributes attr) override; Allocator* GetAllocator(AllocatorAttributes attr) override
LOCKS_EXCLUDED(mu_);
void Compute(OpKernel* op_kernel, OpKernelContext* context) override; void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) override; AsyncOpKernel::DoneCallback done) override;
Status Sync() override { return Status::OK(); } Status Sync() override { return Status::OK(); }
Status FillContextMap(const Graph* graph, Status FillContextMap(const Graph* graph,
DeviceContextMap* device_context_map) override; DeviceContextMap* device_context_map) override
LOCKS_EXCLUDED(mu_);
Status MakeTensorFromProto(const TensorProto& tensor_proto, Status MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs, const AllocatorAttributes alloc_attrs,
Tensor* tensor) override; Tensor* tensor) override LOCKS_EXCLUDED(mu_);
xla::LocalClient* client() const;
const Metadata& metadata() { return xla_metadata_; } const Metadata& metadata() { return xla_metadata_; }
xla::StatusOr<se::Stream*> GetStream();
xla::StatusOr<se::Stream*> GetHostToDeviceStream();
xla::StatusOr<se::Stream*> GetDeviceToHostStream();
// If not already set, create and set GpuDeviceInfo. // Ensures the DeviceContext associated with this XlaDevice is created and
// Not thread-safe // valid (i.e. all streams are ok). If any state is not valid, a new
Status CreateAndSetGpuDeviceInfo(); // DeviceContext will be created.
//
// TODO(b/111859745): The Eager context needs to call this method to recover
// from failures.
Status EnsureDeviceContextOk() LOCKS_EXCLUDED(mu_);
// Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra
// information for GPU and TPU devices.
Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_);
private: private:
xla::LocalClient* client() const;
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
xla::StreamPool::Ptr* stream,
bool* stream_was_changed)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
xla::StatusOr<XlaDeviceContext*> GetDeviceContextLocked()
EXCLUSIVE_LOCKS_REQUIRED(mu_);
mutex mu_;
// The metadata of this XlaDevice. // The metadata of this XlaDevice.
const Metadata xla_metadata_; const Metadata xla_metadata_;
// Which hardware device in the client's platform this XlaDevice controls. // Which hardware device in the client's platform this XlaDevice controls.
const int device_ordinal_; const int device_ordinal_;
// The name of the device that is used to compile Ops for this XlaDevice. // The name of the device that is used to compile Ops for this XlaDevice.
DeviceType jit_device_name_; const DeviceType jit_device_name_;
// The platform for this device.
se::Platform* const platform_; // Not owned.
// Memory allocator associated with this device. // Memory allocator associated with this device.
Allocator* xla_allocator_; // Not owned. Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned.
se::Platform* platform_; // Not owned.
// Stream associated with this device. Operations enqueued on this // Stream associated with this device. Operations enqueued on this
// stream are executed on the device. Operations include data // stream are executed on the device. Operations include data
// copying back and forth between CPU and the device, and // copying back and forth between CPU and the device, and
// computations enqueued by XLA. // computations enqueued by XLA.
xla::StreamPool::Ptr stream_; xla::StreamPool::Ptr stream_ GUARDED_BY(mu_);
// If true, only stream_ is valid and all computation and transfers use // If false, only stream_ is valid and all computation and transfers use
// stream_. If false, computation is performed by stream_ and transfers are // stream_. If true, computation is performed by stream_ and transfers are
// performed by host_to_device/device_to_host_stream. // performed by host_to_device/device_to_host_stream.
bool use_multiple_streams_; const bool use_multiple_streams_;
// If use_multiple_streams_, host to device transfers are performed using this // If use_multiple_streams_, host to device transfers are performed using this
// stream. // stream.
xla::StreamPool::Ptr host_to_device_stream_; xla::StreamPool::Ptr host_to_device_stream_ GUARDED_BY(mu_);
// If use_multiple_streams_, device to host transfers are performed using this // If use_multiple_streams_, device to host transfers are performed using this
// stream. // stream.
xla::StreamPool::Ptr device_to_host_stream_; xla::StreamPool::Ptr device_to_host_stream_ GUARDED_BY(mu_);
// Must we use XLA's transfer manager for correct host<->device transfers? if // Must we use XLA's transfer manager for correct host<->device transfers? if
// false, we can use ThenMemcpy() instead. // false, we can use ThenMemcpy() instead.
bool transfer_as_literal_; const bool transfer_as_literal_;
XlaCompiler::ShapeRepresentationFn shape_representation_fn_; const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
// If set, holds default device context (that we must Unref) // The device context accessed by all users of the XlaDevice, set by calls to
// and its stream. // EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer is
std::unique_ptr<GpuDeviceInfo> gpu_device_info_; // also filled in to that struct. XlaDeviceContext is a ref-counted object.
XlaDeviceContext* device_context_ GUARDED_BY(mu_) = nullptr;
// Holds extra information for GPU and TPU devices, e.g. the device context.
bool use_gpu_device_info_ GUARDED_BY(mu_) = false;
std::unique_ptr<GpuDeviceInfo> gpu_device_info_ GUARDED_BY(mu_);
}; };
// Builds OpKernel registrations on 'device' for the JIT operators // Builds OpKernel registrations on 'device' for the JIT operators

View File

@ -101,34 +101,27 @@ Status XlaTransferManager::TransferLiteralToDevice(
// Unref the host tensor, and capture the literal shared_ptr too so it goes // Unref the host tensor, and capture the literal shared_ptr too so it goes
// out of scope when the lambda completes. // out of scope when the lambda completes.
host_to_device_stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); }); host_to_device_stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); });
return Status::OK(); return Status::OK();
} }
void XlaTransferManager::TransferLiteralFromDevice( void XlaTransferManager::TransferLiteralFromDevice(
Tensor* host_tensor, const Tensor& device_tensor, Tensor* host_tensor, const Tensor& device_tensor,
const StatusCallback& done) const { const StatusCallback& done) const {
xla::MutableBorrowingLiteral literal;
TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(host_tensor, &literal));
const xla::ShapedBuffer& shaped_buffer = const xla::ShapedBuffer& shaped_buffer =
XlaTensor::FromTensor(&device_tensor)->shaped_buffer(); XlaTensor::FromTensor(&device_tensor)->shaped_buffer();
TensorReference ref(device_tensor); TensorReference ref(device_tensor);
transfer_manager_->TransferLiteralFromDevice( transfer_manager_->TransferLiteralFromDevice(
device_to_host_stream_, shaped_buffer, device_to_host_stream_, shaped_buffer, literal,
[=, &shaped_buffer]( [=, &shaped_buffer, &literal](xla::Status status) {
xla::StatusOr<std::unique_ptr<xla::Literal> > literal_or) {
ref.Unref(); ref.Unref();
done([&]() -> Status { done([&]() -> Status {
TF_ASSIGN_OR_RETURN(auto literal, std::move(literal_or)); VLOG(1) << "Transfer from device as literal: " << literal.ToString()
VLOG(1) << "Transfer from device as literal: " << literal->ToString()
<< " " << shaped_buffer.ToString(); << " " << shaped_buffer.ToString();
Tensor tensor;
TF_RETURN_IF_ERROR(
LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor));
// Reshape the tensor back to its declared shape.
Status status;
if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) {
status = errors::Internal(
"Tensor::CopyFrom failed when copying from XLA device to CPU");
}
return status; return status;
}()); }());
}); });

View File

@ -23,7 +23,11 @@ limitations under the License.
#include "tensorflow/core/kernels/cast_op.h" #include "tensorflow/core/kernels/cast_op.h"
#include "tensorflow/core/kernels/constant_op.h" #include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/kernels/control_flow_ops.h" #include "tensorflow/core/kernels/control_flow_ops.h"
#include "tensorflow/core/kernels/data/generator_dataset_op.h"
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include "tensorflow/core/kernels/data/prefetch_dataset_op.h"
#include "tensorflow/core/kernels/fifo_queue.h" #include "tensorflow/core/kernels/fifo_queue.h"
#include "tensorflow/core/kernels/function_ops.h"
#include "tensorflow/core/kernels/identity_n_op.h" #include "tensorflow/core/kernels/identity_n_op.h"
#include "tensorflow/core/kernels/identity_op.h" #include "tensorflow/core/kernels/identity_op.h"
#include "tensorflow/core/kernels/no_op.h" #include "tensorflow/core/kernels/no_op.h"
@ -166,7 +170,69 @@ class XlaAssignVariableOp : public AsyncOpKernel {
QueueIsClosedOp); \ QueueIsClosedOp); \
\ \
REGISTER_KERNEL_BUILDER( \ REGISTER_KERNEL_BUILDER( \
Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); \
\
REGISTER_KERNEL_BUILDER( \
Name(kArgOp).Device(DEVICE).HostMemory("output").TypeConstraint("T", \
TYPES), \
ArgOp); \
REGISTER_KERNEL_BUILDER(Name(kArgOp) \
.Device(DEVICE) \
.HostMemory("output") \
.TypeConstraint<ResourceHandle>("T"), \
ArgOp); \
\
REGISTER_KERNEL_BUILDER(Name(kRetOp) \
.Device(DEVICE) \
.TypeConstraint("T", TYPES) \
.HostMemory("input"), \
RetvalOp); \
REGISTER_KERNEL_BUILDER(Name(kRetOp) \
.Device(DEVICE) \
.TypeConstraint<ResourceHandle>("T") \
.HostMemory("input"), \
RetvalOp); \
\
REGISTER_KERNEL_BUILDER( \
Name("RemoteCall").Device(DEVICE).HostMemory("target"), RemoteCallOp); \
\
REGISTER_KERNEL_BUILDER( \
Name("GeneratorDataset").Device(DEVICE).HostMemory("handle"), \
GeneratorDatasetOp); \
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") \
.Device(DEVICE) \
.HostMemory("buffer_size") \
.HostMemory("input_dataset") \
.HostMemory("handle"), \
PrefetchDatasetOp); \
\
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE), \
IteratorHandleOp); \
REGISTER_KERNEL_BUILDER( \
Name("MakeIterator").Device(DEVICE).HostMemory("dataset"), \
MakeIteratorOp); \
REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \
AnonymousIteratorHandleOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \
IteratorGetNextOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \
.Device(DEVICE) \
.HostMemory("string_handle"), \
IteratorToStringHandleOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2") \
.Device(DEVICE) \
.HostMemory("string_handle"), \
IteratorFromStringHandleOp); \
REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp) \
.Device(DEVICE) \
.HostMemory("output") \
.TypeConstraint<string>("T"), \
ArgOp); \
REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kRetOp) \
.Device(DEVICE) \
.TypeConstraint<string>("T") \
.HostMemory("input"), \
RetvalOp);
// TODO(phawkins): currently we do not register the QueueEnqueueMany, // TODO(phawkins): currently we do not register the QueueEnqueueMany,
// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read // QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read

View File

@ -59,7 +59,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
} }
// TODO(b/78468222): Uncomment after fixing this bug // TODO(b/78468222): Uncomment after fixing this bug
// status = device->CreateAndSetGpuDeviceInfo(); // status = device->UseGpuDeviceInfo();
// if (!status.ok()) { // if (!status.ok()) {
// errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT, // errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT,
// " device"); // " device");

View File

@ -691,11 +691,7 @@ tf_xla_py_test(
size = "small", size = "small",
srcs = ["random_ops_test.py"], srcs = ["random_ops_test.py"],
disabled_backends = [ disabled_backends = [
# TODO(b/110300529): RngNormal doesn't return values with the expected variance
"cpu",
"cpu_ondemand", "cpu_ondemand",
# TODO(b/31361304): enable RNG ops on GPU when parallelized.
"gpu",
], ],
deps = [ deps = [
":xla_test", ":xla_test",

View File

@ -52,6 +52,9 @@ class AdamOptimizerTest(xla_test.XLATestCase):
def testBasic(self): def testBasic(self):
for dtype in self.float_types: for dtype in self.float_types:
# TODO: test fails for float16 due to excessive precision requirements.
if dtype == np.float16:
continue
with self.test_session(), self.test_scope(): with self.test_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True) variable_scope.get_variable_scope().set_use_resource(True)
@ -91,6 +94,9 @@ class AdamOptimizerTest(xla_test.XLATestCase):
def testTensorLearningRate(self): def testTensorLearningRate(self):
for dtype in self.float_types: for dtype in self.float_types:
# TODO: test fails for float16 due to excessive precision requirements.
if dtype == np.float16:
continue
with self.test_session(), self.test_scope(): with self.test_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True) variable_scope.get_variable_scope().set_use_resource(True)
@ -130,6 +136,9 @@ class AdamOptimizerTest(xla_test.XLATestCase):
def testSharing(self): def testSharing(self):
for dtype in self.float_types: for dtype in self.float_types:
# TODO: test fails for float16 due to excessive precision requirements.
if dtype == np.float16:
continue
with self.test_session(), self.test_scope(): with self.test_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True) variable_scope.get_variable_scope().set_use_resource(True)

View File

@ -32,6 +32,7 @@ from tensorflow.python.layers import convolutional
from tensorflow.python.layers import pooling from tensorflow.python.layers import pooling
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops from tensorflow.python.ops import nn_ops
@ -122,6 +123,14 @@ class EagerTest(xla_test.XLATestCase):
with self.test_scope(): with self.test_scope():
self.assertAllEqual(2, array_ops.identity(2)) self.assertAllEqual(2, array_ops.identity(2))
def testRandomOps(self):
with self.test_scope():
tensor = gen_random_ops.random_uniform((2, 2), dtypes.float32)
row0 = tensor[0].numpy()
row1 = tensor[1].numpy()
# It should be very unlikely for the rng to generate two equal rows.
self.assertFalse((row0 == row1).all())
def testIdentityOnVariable(self): def testIdentityOnVariable(self):
with self.test_scope(): with self.test_scope():
v = resource_variable_ops.ResourceVariable(True) v = resource_variable_ops.ResourceVariable(True)

View File

@ -57,7 +57,8 @@ class RandomOpsTest(xla_test.XLATestCase):
def testRandomUniformIsNotConstant(self): def testRandomUniformIsNotConstant(self):
def rng(dtype): def rng(dtype):
return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000) dtype = dtypes.as_dtype(dtype)
return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=dtype.max)
for dtype in self._random_types(): for dtype in self._random_types():
self._testRngIsNotConstant(rng, dtype) self._testRngIsNotConstant(rng, dtype)
@ -73,6 +74,11 @@ class RandomOpsTest(xla_test.XLATestCase):
def testRandomUniformIsInRange(self): def testRandomUniformIsInRange(self):
for dtype in self._random_types(): for dtype in self._random_types():
# TODO (b/112272078): enable bfloat16 for CPU and GPU when the bug is
# fixed.
if (self.device in ["XLA_GPU", "XLA_CPU"
]) and (dtype in [dtypes.bfloat16, dtypes.half]):
continue
with self.test_session() as sess: with self.test_session() as sess:
with self.test_scope(): with self.test_scope():
x = random_ops.random_uniform( x = random_ops.random_uniform(
@ -95,7 +101,7 @@ class RandomOpsTest(xla_test.XLATestCase):
for dtype in [dtypes.float32]: for dtype in [dtypes.float32]:
with self.test_session() as sess: with self.test_session() as sess:
with self.test_scope(): with self.test_scope():
x = random_ops.truncated_normal(shape=[count], dtype=dtype, seed=42) x = random_ops.truncated_normal(shape=[count], dtype=dtype)
y = sess.run(x) y = sess.run(x)
def normal_cdf(x): def normal_cdf(x):
@ -124,20 +130,23 @@ class RandomOpsTest(xla_test.XLATestCase):
# Department of Scientific Computing website. Florida State University. # Department of Scientific Computing website. Florida State University.
expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
actual_mean = np.mean(y) actual_mean = np.mean(y)
self.assertAllClose(actual_mean, expected_mean, atol=2e-4) self.assertAllClose(actual_mean, expected_mean, atol=2e-3)
expected_median = mu + probit( expected_median = mu + probit(
(normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma
actual_median = np.median(y) actual_median = np.median(y)
self.assertAllClose(actual_median, expected_median, atol=8e-4) self.assertAllClose(actual_median, expected_median, atol=1e-2)
expected_variance = sigma**2 * (1 + ( expected_variance = sigma**2 * (1 + (
(alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - ( (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - (
(normal_pdf(alpha) - normal_pdf(beta)) / z)**2) (normal_pdf(alpha) - normal_pdf(beta)) / z)**2)
actual_variance = np.var(y) actual_variance = np.var(y)
self.assertAllClose(actual_variance, expected_variance, rtol=3e-4) self.assertAllClose(actual_variance, expected_variance, rtol=2*1e-3)
def testShuffle1d(self): def testShuffle1d(self):
# TODO(b/26783907): this test requires the CPU backend to implement sort.
if self.device in ["XLA_CPU"]:
return
with self.test_session() as sess: with self.test_session() as sess:
with self.test_scope(): with self.test_scope():
x = math_ops.range(1 << 16) x = math_ops.range(1 << 16)
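The loosened tolerances above are checked against the standard moments of a truncated normal, which the test computes as

$$
\mathbb{E}[Y]=\mu+\frac{\varphi(\alpha)-\varphi(\beta)}{Z}\,\sigma,\qquad
\operatorname{Var}[Y]=\sigma^{2}\!\left[1+\frac{\alpha\varphi(\alpha)-\beta\varphi(\beta)}{Z}-\left(\frac{\varphi(\alpha)-\varphi(\beta)}{Z}\right)^{2}\right],
$$

with $\varphi$ the standard normal pdf, $\Phi$ its cdf, $Z=\Phi(\beta)-\Phi(\alpha)$, and $\alpha,\beta$ the standardized truncation bounds; the median check uses $\mu+\Phi^{-1}\!\big((\Phi(\alpha)+\Phi(\beta))/2\big)\,\sigma$.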

View File

@ -361,6 +361,12 @@ class UnaryOpsTest(xla_test.XLATestCase):
np.array([[-0.05, 6.05, 5]], dtype=dtype), np.array([[-0.05, 6.05, 5]], dtype=dtype),
expected=np.array([[0, 6, 5]], dtype=dtype)) expected=np.array([[0, 6, 5]], dtype=dtype))
self._assertOpOutputMatchesExpected(
nn_ops.softmax,
np.array([1, 2, 3, 4], dtype=dtype),
expected=np.array([0.032058604, 0.087144323, 0.23688284, 0.64391428],
dtype=dtype))
self._assertOpOutputMatchesExpected( self._assertOpOutputMatchesExpected(
nn_ops.softmax, nn_ops.softmax,
np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype),
@ -369,6 +375,14 @@ class UnaryOpsTest(xla_test.XLATestCase):
[0.032058604, 0.087144323, 0.23688284, 0.64391428]], [0.032058604, 0.087144323, 0.23688284, 0.64391428]],
dtype=dtype)) dtype=dtype))
self._assertOpOutputMatchesExpected(
nn_ops.softmax,
np.array([[[1, 1], [1, 1]], [[1, 2], [3, 4]]], dtype=dtype),
expected=np.array(
[[[0.5, 0.5], [0.5, 0.5]],
[[0.26894142, 0.73105858], [0.26894142, 0.73105858]]],
dtype=dtype))
self._assertOpOutputMatchesExpected( self._assertOpOutputMatchesExpected(
nn_ops.softsign, nn_ops.softsign,
np.array([[-2, -1, 0, 1, 2]], dtype=dtype), np.array([[-2, -1, 0, 1, 2]], dtype=dtype),
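The expected values in the new softmax cases follow directly from the definition; for the added 1-D and 2x2 inputs,

$$
\operatorname{softmax}(x)_i=\frac{e^{x_i}}{\sum_j e^{x_j}},\qquad
\operatorname{softmax}([1,2,3,4])\approx[0.0321,\;0.0871,\;0.2369,\;0.6439],\qquad
\operatorname{softmax}([1,2])\approx[0.2689,\;0.7311].
$$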

View File

@ -21,6 +21,8 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.ops import gen_control_flow_ops
@ -47,6 +49,34 @@ class XlaDeviceTest(xla_test.XLATestCase):
result = sess.run(z, {x: inputs}) result = sess.run(z, {x: inputs})
self.assertAllCloseAccordingToType(result, inputs + inputs) self.assertAllCloseAccordingToType(result, inputs + inputs)
def testCopiesOfUnsupportedTypesFailGracefully(self):
"""Tests that copies of unsupported types don't crash."""
test_types = set([
np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16, np.int32,
np.int64, np.float16, np.float32, np.float16,
dtypes.bfloat16.as_numpy_dtype
])
shape = (10, 10)
for unsupported_dtype in test_types - self.all_types:
with self.test_session() as sess:
with ops.device("CPU"):
x = array_ops.placeholder(unsupported_dtype, shape)
with self.test_scope():
y, = array_ops.identity_n([x])
with ops.device("CPU"):
z = array_ops.identity(y)
inputs = np.random.randint(-100, 100, shape)
inputs = inputs.astype(unsupported_dtype)
# Execution should either succeed or raise an InvalidArgumentError,
# but not crash. Even "unsupported types" may succeed here since some
# backends (e.g., the CPU backend) are happy to handle buffers of
# unsupported types, even if they cannot compute with them.
try:
sess.run(z, {x: inputs})
except errors.InvalidArgumentError:
pass
def testControlTrigger(self): def testControlTrigger(self):
with self.test_session() as sess: with self.test_session() as sess:
with self.test_scope(): with self.test_scope():

View File

@ -95,6 +95,10 @@ cc_library(
name = "cpu_function_runtime", name = "cpu_function_runtime",
srcs = ["cpu_function_runtime.cc"], srcs = ["cpu_function_runtime.cc"],
hdrs = ["cpu_function_runtime.h"], hdrs = ["cpu_function_runtime.h"],
visibility = [
"//tensorflow/compiler/aot:__pkg__",
"//tensorflow/compiler/xla/service/cpu:__pkg__",
],
deps = [ deps = [
# Keep dependencies to a minimum here; this library is used in every AOT # Keep dependencies to a minimum here; this library is used in every AOT
# binary produced by tfcompile. # binary produced by tfcompile.
@ -144,6 +148,7 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/compiler/xla/service/cpu:buffer_info_util",
"//tensorflow/compiler/xla/service/cpu:cpu_executable", "//tensorflow/compiler/xla/service/cpu:cpu_executable",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",

View File

@ -55,19 +55,26 @@ size_t align_to(size_t n, size_t align) {
} // namespace } // namespace
namespace cpu_function_runtime { namespace cpu_function_runtime {
size_t AlignedBufferBytes(const intptr_t* sizes, size_t n) { size_t AlignedBufferBytes(const BufferInfo* buffer_infos, size_t n,
bool allocate_entry_params) {
size_t total = 0; size_t total = 0;
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
if (sizes[i] > 0) { bool should_allocate =
total += align_to(sizes[i], kAlign); buffer_infos[i].is_temp_buffer() ||
(buffer_infos[i].is_entry_parameter() && allocate_entry_params);
if (should_allocate) {
total += align_to(buffer_infos[i].size(), kAlign);
} }
} }
return total; return total;
} }
void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, void* MallocContiguousBuffers(const BufferInfo* buffer_infos, size_t n,
bool allocate_entry_params, void** bufs,
bool annotate_initialized) { bool annotate_initialized) {
const size_t total = AlignedBufferBytes(sizes, n); const size_t total =
AlignedBufferBytes(buffer_infos, n, allocate_entry_params);
void* contiguous = nullptr; void* contiguous = nullptr;
if (total > 0) { if (total > 0) {
contiguous = aligned_malloc(total, kAlign); contiguous = aligned_malloc(total, kAlign);
@ -79,13 +86,14 @@ void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
} }
uintptr_t pos = reinterpret_cast<uintptr_t>(contiguous); uintptr_t pos = reinterpret_cast<uintptr_t>(contiguous);
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
if (sizes[i] < 0) { bool should_allocate =
// bufs[i] is either a constant, an entry parameter or a thread local buffer_infos[i].is_temp_buffer() ||
// allocation. (buffer_infos[i].is_entry_parameter() && allocate_entry_params);
bufs[i] = nullptr; if (should_allocate) {
} else {
bufs[i] = reinterpret_cast<void*>(pos); bufs[i] = reinterpret_cast<void*>(pos);
pos += align_to(sizes[i], kAlign); pos += align_to(buffer_infos[i].size(), kAlign);
} else {
bufs[i] = nullptr;
} }
} }
return contiguous; return contiguous;
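The updated allocator keeps the original strategy (sum the `kAlign`-rounded sizes, make one allocation, and hand out offsets into it) while deciding per buffer whether to allocate from the new `BufferInfo` kind rather than a signed size. A standalone sketch of the carve-one-block approach, with a plain `should_allocate` flag standing in for the temp-buffer/entry-parameter test:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <vector>

constexpr size_t kAlign = 64;

size_t AlignTo(size_t n, size_t align) { return (n + align - 1) / align * align; }

// Allocates one contiguous, kAlign-aligned block large enough for every buffer
// that should be allocated, and fills `bufs` with pointers into it (nullptr for
// buffers the caller provides itself, e.g. constants or on-stack buffers).
void* MallocContiguous(const std::vector<size_t>& sizes,
                       const std::vector<bool>& should_allocate,
                       std::vector<void*>* bufs) {
  size_t total = 0;
  for (size_t i = 0; i < sizes.size(); ++i) {
    if (should_allocate[i]) total += AlignTo(sizes[i], kAlign);
  }
  // total is a multiple of kAlign, as std::aligned_alloc requires.
  void* contiguous = total > 0 ? std::aligned_alloc(kAlign, total) : nullptr;
  uintptr_t pos = reinterpret_cast<uintptr_t>(contiguous);
  bufs->assign(sizes.size(), nullptr);
  for (size_t i = 0; i < sizes.size(); ++i) {
    if (should_allocate[i]) {
      (*bufs)[i] = reinterpret_cast<void*>(pos);
      pos += AlignTo(sizes[i], kAlign);
    }
  }
  return contiguous;  // the caller frees this single block with std::free()
}
```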

View File

@ -18,29 +18,142 @@ limitations under the License.
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include <cassert>
namespace tensorflow { namespace tensorflow {
namespace cpu_function_runtime { namespace cpu_function_runtime {
// Stores information about one buffer used by an XLA:CPU compiled function.
// These buffers are used for holding inputs to the computation, outputs from
// the computation and as temporary scratch space.
class BufferInfo {
public:
// Creates a BufferInfo from a serialized encoding generated by `Encode`.
explicit BufferInfo(std::pair<uint64, uint64> encoding)
: entry_param_number_(encoding.second) {
Kind kind;
uint64 size;
Unpack(encoding.first, &kind, &size);
kind_ = kind;
size_ = size;
}
// Returns true if this buffer stores a constant. These never need to be
// allocated by the runtime.
bool is_constant() const { return kind() == Kind::kConstant; }
// Returns true if this buffer stores an entry parameter. These may or may
// not need to be allocated by the runtime, depending on
// XlaCompiledCpuFunction::AllocMode.
bool is_entry_parameter() const { return kind() == Kind::kEntryParameter; }
// Returns the entry parameter number of this buffer.
uint64 entry_parameter_number() const {
assert(is_entry_parameter());
return entry_param_number_;
}
// Returns true if this buffer is temporary scratch space required by the XLA
// computations. These are always allocated by the runtime.
bool is_temp_buffer() const { return kind() == Kind::kTempBuffer; }
// Returns true if this buffer is allocated on the C stack or into registers.
// These buffers are never allocated by the runtime.
bool is_on_stack_buffer() const { return kind() == Kind::kOnStackBuffer; }
// Returns the size for this buffer.
uint64 size() const { return size_; }
// Encodes this BufferInfo into two 64 bit integers that can be used to
// reconstruct the BufferInfo later using the constructor. We need this
// because we use BufferInfo in places where using protocol buffers would
// negatively impact binary size.
std::pair<uint64, uint64> Encode() const {
static_assert(sizeof(*this) == 16, "");
uint64 upper = Pack(kind(), size_);
uint64 lower = entry_param_number_;
return {upper, lower};
}
bool operator==(const BufferInfo& buffer_info) const {
if (kind() != buffer_info.kind() || size() != buffer_info.size()) {
return false;
}
return !is_entry_parameter() ||
entry_parameter_number() == buffer_info.entry_parameter_number();
}
// Factory methods:
static BufferInfo MakeTempBuffer(uint64 size) {
return BufferInfo(Kind::kTempBuffer, /*size=*/size,
/*entry_param_number=*/-1);
}
static BufferInfo MakeConstant(uint64 size) {
return BufferInfo(Kind::kConstant, /*size=*/size,
/*entry_param_number=*/-1);
}
static BufferInfo MakeEntryParameter(uint64 size, uint64 param_number) {
return BufferInfo(Kind::kEntryParameter, /*size=*/size,
/*entry_param_number=*/param_number);
}
static BufferInfo MakeOnStackBuffer(uint64 size) {
return BufferInfo(Kind::kOnStackBuffer, /*size=*/size,
/*entry_param_number=*/-1);
}
private:
BufferInfo() = default;
enum class Kind : unsigned {
kConstant,
kTempBuffer,
kEntryParameter,
kOnStackBuffer
};
Kind kind() const { return static_cast<Kind>(kind_); }
explicit BufferInfo(Kind kind, uint64 size, uint64 entry_param_number)
: kind_(kind), size_(size), entry_param_number_(entry_param_number) {}
static uint64 Pack(Kind kind, uint64 size) {
return (static_cast<uint64>(size) << 2) | static_cast<uint64>(kind);
}
static void Unpack(uint64 packed, Kind* kind, uint64* size) {
*size = packed >> 2;
*kind = static_cast<Kind>((packed << 62) >> 62);
}
Kind kind_ : 2;
uint64 size_ : 62;
int64 entry_param_number_;
};
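A minimal sketch of the Encode/constructor round trip described above (assuming the class is reachable via "tensorflow/compiler/tf2xla/cpu_function_runtime.h"; the helper name and the 128-byte parameter are illustrative):

#include <cassert>
#include <utility>

#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"

namespace tensorflow {
namespace cpu_function_runtime {

inline void BufferInfoRoundTripSketch() {
  // An entry-parameter buffer of 128 bytes bound to parameter number 3.
  BufferInfo original = BufferInfo::MakeEntryParameter(/*size=*/128,
                                                       /*param_number=*/3);
  // Encode() packs the kind into the low 2 bits and the size into the upper
  // 62 bits of the first word; the second word carries the parameter number.
  std::pair<uint64, uint64> encoding = original.Encode();
  // Rebuilding from the encoding yields an equal BufferInfo.
  BufferInfo round_trip(encoding);
  assert(round_trip == original);
}

}  // namespace cpu_function_runtime
}  // namespace tensorflow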
// Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. // Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment.
constexpr size_t kAlign = 64; constexpr size_t kAlign = 64;
// AlignedBufferBytes returns the sum of each size in `sizes`, skipping -1 // AlignedBufferBytes returns the sum of the size of each buffer in
// values. There are `n` entries in `sizes`. Each buffer is aligned to // `buffer_infos`, skipping constants, on-stack buffers and, if
// kAlign byte boundaries. // allocate_entry_params is false, entry parameters. There are `n` entries in
size_t AlignedBufferBytes(const intptr_t* sizes, size_t n); // `buffer_infos`. Each buffer is aligned to kAlign byte boundaries.
size_t AlignedBufferBytes(const BufferInfo* buffer_infos, size_t n,
bool allocate_entry_params);
// MallocContiguousBuffers allocates buffers for use by the entry point // MallocContiguousBuffers allocates buffers for use by the entry point
// generated by tfcompile. `sizes` is an array of byte sizes for each buffer, // generated by tfcompile. There are `n` entries in `buffer_infos`. If
// where -1 causes the buffer pointer to be nullptr. There are `n` entries in // `annotate_initialized` is set, the allocated memory will be annotated as
// `sizes`. If `annotate_initialized` is set, the allocated memory will be // having been initialized - this is useful when allocating temporary buffers.
// annotated as having been initialized - this is useful when allocating // If allocate_entry_params is true then allocates temp buffers and entry
// temporary buffers. // parameters, otherwise allocates only temp buffers. Slots in `bufs`
// corresponding to unallocated buffers are set to nullptr.
// //
// A single contiguous block of memory is allocated, and portions of it are // A single contiguous block of memory is allocated, and portions of it are
// parceled out into `bufs`, which must have space for `n` entries. Returns // parceled out into `bufs`, which must have space for `n` entries. Returns
// the head of the allocated contiguous block, which should be passed to // the head of the allocated contiguous block, which should be passed to
// FreeContiguous when the buffers are no longer in use. // FreeContiguous when the buffers are no longer in use.
void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, void* MallocContiguousBuffers(const BufferInfo* buffer_infos, size_t n,
bool allocate_entry_params, void** bufs,
bool annotate_initialized); bool annotate_initialized);
// FreeContiguous frees the contiguous block of memory allocated by // FreeContiguous frees the contiguous block of memory allocated by
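A minimal usage sketch for the allocation helpers declared in this header (the buffer sizes and the wrapper function name are illustrative):

#include <cassert>

#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"

namespace tensorflow {

inline void ContiguousBufferSketch() {
  using cpu_function_runtime::BufferInfo;
  // Two temp buffers and one constant; only the temps receive runtime storage.
  const BufferInfo infos[3] = {BufferInfo::MakeTempBuffer(10),
                               BufferInfo::MakeConstant(100),
                               BufferInfo::MakeTempBuffer(1)};
  void* bufs[3];
  void* block = cpu_function_runtime::MallocContiguousBuffers(
      infos, 3, /*allocate_entry_params=*/false, bufs,
      /*annotate_initialized=*/true);
  // Allocated buffers are kAlign (64-byte) aligned slices of one contiguous
  // block; the constant's slot is set to nullptr.
  assert(bufs[1] == nullptr);
  // ... use bufs[0] and bufs[2] ...
  cpu_function_runtime::FreeContiguous(block);
}

}  // namespace tensorflow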
@ -21,6 +21,8 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace { namespace {
using cpu_function_runtime::BufferInfo;
TEST(XlaCompiledCpuFunctionTest, AlignmentValue) { TEST(XlaCompiledCpuFunctionTest, AlignmentValue) {
// We've chosen 64 byte alignment for the tfcompile runtime to mimic the // We've chosen 64 byte alignment for the tfcompile runtime to mimic the
// regular tensorflow allocator, which was chosen to play nicely with Eigen. // regular tensorflow allocator, which was chosen to play nicely with Eigen.
@ -30,20 +32,51 @@ TEST(XlaCompiledCpuFunctionTest, AlignmentValue) {
EXPECT_EQ(cpu_function_runtime::kAlign, Allocator::kAllocatorAlignment); EXPECT_EQ(cpu_function_runtime::kAlign, Allocator::kAllocatorAlignment);
} }
std::vector<BufferInfo> SizesToBufferInfos(const intptr_t* sizes, size_t n) {
std::vector<BufferInfo> buffer_infos;
std::transform(sizes, sizes + n, std::back_inserter(buffer_infos),
[&](intptr_t size) {
if (size == -1) {
// Use a dummy on-stack buffer allocation to indicat the
// the current slot does not need an allocation.
int64 on_stack_buffer_size = 4;
return BufferInfo::MakeOnStackBuffer(on_stack_buffer_size);
}
return BufferInfo::MakeTempBuffer(size);
});
return buffer_infos;
}
// Simple wrappers to make writing tests more ergonomic.
size_t AlignedBufferBytesFromSizes(const intptr_t* sizes, size_t n) {
std::vector<BufferInfo> buffer_infos = SizesToBufferInfos(sizes, n);
return AlignedBufferBytes(buffer_infos.data(), n,
/*allocate_entry_params=*/false);
}
void* MallocContiguousBuffersFromSizes(const intptr_t* sizes, size_t n,
void** bufs, bool annotate_initialized) {
std::vector<BufferInfo> buffer_infos = SizesToBufferInfos(sizes, n);
return MallocContiguousBuffers(buffer_infos.data(), n,
/*allocate_entry_params=*/false, bufs,
annotate_initialized);
}
TEST(XlaCompiledCpuFunctionTest, AlignedBufferBytes) { TEST(XlaCompiledCpuFunctionTest, AlignedBufferBytes) {
EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(nullptr, 0), 0); EXPECT_EQ(AlignedBufferBytesFromSizes(nullptr, 0), 0);
static constexpr intptr_t sizesA[1] = {-1}; static constexpr intptr_t sizesA[1] = {-1};
EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesA, 1), 0); EXPECT_EQ(AlignedBufferBytesFromSizes(sizesA, 1), 0);
static constexpr intptr_t sizesB[1] = {3}; static constexpr intptr_t sizesB[1] = {3};
EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesB, 1), 64); EXPECT_EQ(AlignedBufferBytesFromSizes(sizesB, 1), 64);
static constexpr intptr_t sizesC[1] = {32}; static constexpr intptr_t sizesC[1] = {32};
EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesC, 1), 64); EXPECT_EQ(AlignedBufferBytesFromSizes(sizesC, 1), 64);
static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3}; static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3};
EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesD, 7), 320); EXPECT_EQ(AlignedBufferBytesFromSizes(sizesD, 7), 320);
} }
void* add_ptr(void* base, uintptr_t delta) { void* add_ptr(void* base, uintptr_t delta) {
@ -56,15 +89,14 @@ void* add_ptr(void* base, uintptr_t delta) {
// free. We also check the contiguous property. // free. We also check the contiguous property.
TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) { TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
// Test empty sizes. // Test empty sizes.
void* base = void* base = MallocContiguousBuffersFromSizes(nullptr, 0, nullptr, false);
cpu_function_runtime::MallocContiguousBuffers(nullptr, 0, nullptr, false);
EXPECT_EQ(base, nullptr); EXPECT_EQ(base, nullptr);
cpu_function_runtime::FreeContiguous(base); cpu_function_runtime::FreeContiguous(base);
// Test non-empty sizes with 0 sum. // Test non-empty sizes with 0 sum.
static constexpr intptr_t sizesA[1] = {-1}; static constexpr intptr_t sizesA[1] = {-1};
void* bufA[1]; void* bufA[1];
base = cpu_function_runtime::MallocContiguousBuffers(sizesA, 1, bufA, false); base = MallocContiguousBuffersFromSizes(sizesA, 1, bufA, false);
EXPECT_EQ(base, nullptr); EXPECT_EQ(base, nullptr);
EXPECT_EQ(bufA[0], nullptr); EXPECT_EQ(bufA[0], nullptr);
cpu_function_runtime::FreeContiguous(base); cpu_function_runtime::FreeContiguous(base);
@ -72,7 +104,7 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
// Test non-empty sizes with non-0 sum. // Test non-empty sizes with non-0 sum.
static constexpr intptr_t sizesB[1] = {3}; static constexpr intptr_t sizesB[1] = {3};
void* bufB[1]; void* bufB[1];
base = cpu_function_runtime::MallocContiguousBuffers(sizesB, 1, bufB, false); base = MallocContiguousBuffersFromSizes(sizesB, 1, bufB, false);
EXPECT_NE(base, nullptr); EXPECT_NE(base, nullptr);
EXPECT_EQ(bufB[0], add_ptr(base, 0)); EXPECT_EQ(bufB[0], add_ptr(base, 0));
char* bufB0_bytes = static_cast<char*>(bufB[0]); char* bufB0_bytes = static_cast<char*>(bufB[0]);
@ -84,7 +116,7 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
// Test non-empty sizes with non-0 sum, and annotate_initialized. // Test non-empty sizes with non-0 sum, and annotate_initialized.
static constexpr intptr_t sizesC[1] = {3}; static constexpr intptr_t sizesC[1] = {3};
void* bufC[1]; void* bufC[1];
base = cpu_function_runtime::MallocContiguousBuffers(sizesC, 1, bufC, true); base = MallocContiguousBuffersFromSizes(sizesC, 1, bufC, true);
EXPECT_NE(base, nullptr); EXPECT_NE(base, nullptr);
EXPECT_EQ(bufC[0], add_ptr(base, 0)); EXPECT_EQ(bufC[0], add_ptr(base, 0));
char* bufC0_bytes = static_cast<char*>(bufC[0]); char* bufC0_bytes = static_cast<char*>(bufC[0]);
@ -96,7 +128,7 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
// Test mixed sizes. // Test mixed sizes.
static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3}; static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3};
void* bufD[7]; void* bufD[7];
base = cpu_function_runtime::MallocContiguousBuffers(sizesD, 7, bufD, false); base = MallocContiguousBuffersFromSizes(sizesD, 7, bufD, false);
EXPECT_NE(base, nullptr); EXPECT_NE(base, nullptr);
EXPECT_EQ(bufD[0], add_ptr(base, 0)); EXPECT_EQ(bufD[0], add_ptr(base, 0));
EXPECT_EQ(bufD[1], nullptr); EXPECT_EQ(bufD[1], nullptr);
@ -117,5 +149,23 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
cpu_function_runtime::FreeContiguous(base); cpu_function_runtime::FreeContiguous(base);
} }
void CheckRoundTripIsOk(const BufferInfo& buffer_info) {
BufferInfo round_trip(buffer_info.Encode());
ASSERT_EQ(round_trip, buffer_info);
}
TEST(XlaCompiledCpuFunctionTest, BufferInfoTest) {
CheckRoundTripIsOk(BufferInfo::MakeTempBuffer(0));
CheckRoundTripIsOk(BufferInfo::MakeTempBuffer(4));
CheckRoundTripIsOk(BufferInfo::MakeOnStackBuffer(0));
CheckRoundTripIsOk(BufferInfo::MakeOnStackBuffer(4));
CheckRoundTripIsOk(BufferInfo::MakeConstant(0));
CheckRoundTripIsOk(BufferInfo::MakeConstant(4));
CheckRoundTripIsOk(
BufferInfo::MakeEntryParameter(/*size=*/0, /*param_number=*/4));
CheckRoundTripIsOk(
BufferInfo::MakeEntryParameter(/*size=*/4, /*param_number=*/0));
}
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow
@ -6,6 +6,10 @@ package(
load("//tensorflow:tensorflow.bzl", "tf_copts") load("//tensorflow:tensorflow.bzl", "tf_copts")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
)
tf_kernel_library( tf_kernel_library(
name = "xla_ops", name = "xla_ops",
@ -129,6 +133,7 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/compiler/xla/client/lib:numeric",
"//tensorflow/compiler/xla/client/lib:pooling",
"//tensorflow/compiler/xla/client/lib:prng", "//tensorflow/compiler/xla/client/lib:prng",
"//tensorflow/compiler/xla/client/lib:sorting", "//tensorflow/compiler/xla/client/lib:sorting",
"//tensorflow/core:framework", "//tensorflow/core:framework",
@ -153,8 +158,14 @@ tf_kernel_library(
"//tensorflow/core/kernels:sparse_to_dense_op", "//tensorflow/core/kernels:sparse_to_dense_op",
"//tensorflow/core/kernels:stack_ops", "//tensorflow/core/kernels:stack_ops",
"//tensorflow/core/kernels:training_ops", "//tensorflow/core/kernels:training_ops",
] + if_mkl(
[
"//tensorflow/core/kernels:mkl_transpose_op",
],
[
"//tensorflow/core/kernels:transpose_op", "//tensorflow/core/kernels:transpose_op",
], ],
),
) )
tf_kernel_library( tf_kernel_library(
@ -65,6 +65,6 @@ class XlaArgOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(XlaArgOp); TF_DISALLOW_COPY_AND_ASSIGN(XlaArgOp);
}; };
REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes(), XlaArgOp); REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes().CompilationOnly(), XlaArgOp);
} // namespace tensorflow } // namespace tensorflow
@ -200,12 +200,23 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
} }
} }
bool resource_variable_seen = false;
for (int i = 0; i < ctx->num_inputs(); ++i) {
if (ctx->input_type(i) == DT_RESOURCE) {
resource_variable_seen = true;
} else {
OP_REQUIRES(
ctx, !resource_variable_seen,
errors::FailedPrecondition(
"Resource variables and regular inputs cannot be interleaved."));
}
}
xla::XlaOp outputs = xla::Conditional( xla::XlaOp outputs = xla::Conditional(
ctx->Input(0), xla::Tuple(b, inputs), *then_result.computation, ctx->Input(0), xla::Tuple(b, inputs), *then_result.computation,
xla::Tuple(b, inputs), *else_result.computation); xla::Tuple(b, inputs), *else_result.computation);
// Sets non-variable outputs. // Sets non-variable outputs.
for (int i = 0; i < output_types_.size(); ++i) { for (int i = 0; i < output_types_.size(); ++i) {
if (ctx->input_type(i) != DT_RESOURCE) {
xla::XlaOp output_handle = xla::GetTupleElement(outputs, i); xla::XlaOp output_handle = xla::GetTupleElement(outputs, i);
if (VLOG_IS_ON(2)) { if (VLOG_IS_ON(2)) {
LOG(INFO) << "Setting output " << i; LOG(INFO) << "Setting output " << i;
@ -219,7 +230,6 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
} }
ctx->SetOutput(i, output_handle); ctx->SetOutput(i, output_handle);
} }
}
// Updates the values of any resource variables modified by the conditional // Updates the values of any resource variables modified by the conditional
// bodies. // bodies.
@ -247,6 +257,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
} }
REGISTER_XLA_OP(Name("If").AllowResourceTypes(), XlaIfOp); REGISTER_XLA_OP(Name("If").AllowResourceTypes(), XlaIfOp);
REGISTER_XLA_OP(Name("StatelessIf").AllowResourceTypes(), XlaIfOp);
REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp); REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp);
} // namespace tensorflow } // namespace tensorflow
@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/pooling.h"
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal.h"
@ -71,59 +72,53 @@ class PoolingOp : public XlaOpKernel {
int num_dims() const { return num_spatial_dims_ + 2; } int num_dims() const { return num_spatial_dims_ + 2; }
// Method that builds an initial value to use in reductions. protected:
virtual xla::XlaOp InitValue(xla::XlaBuilder* b) = 0; xla::StatusOr<std::vector<int64>> GetKernelSize(XlaOpKernelContext* ctx) {
if (ctx->num_inputs() == 1) {
// The reduction operation to apply to each window. return ksize_;
virtual const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) = 0; }
// A post-processing operation to apply on the outputs of the ReduceWindow.
virtual xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx,
const xla::XlaOp& output, DataType dtype,
const TensorShape& input_shape) = 0;
void Compile(XlaOpKernelContext* ctx) override {
std::vector<int64> ksize = ksize_;
std::vector<int64> stride = stride_;
if (ctx->num_inputs() != 1) {
const TensorShape ksize_shape = ctx->InputShape(1); const TensorShape ksize_shape = ctx->InputShape(1);
// Validate input sizes. // Validate input sizes.
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape), if (!TensorShapeUtils::IsVector(ksize_shape)) {
errors::InvalidArgument("ksize must be a vector, not shape ", return errors::InvalidArgument("ksize must be a vector, not shape ",
ksize_shape.DebugString())); ksize_shape.DebugString());
OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(), }
errors::InvalidArgument("Sliding window ksize field must " if (ksize_shape.num_elements() != num_dims()) {
return errors::InvalidArgument(
"Sliding window ksize field must "
"specify ", "specify ",
num_dims(), " dimensions")); num_dims(), " dimensions");
ksize.clear(); }
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize)); std::vector<int64> ksize;
auto status = ctx->ConstantInputAsIntVector(1, &ksize);
if (!status.ok()) {
return status;
}
return ksize;
}
xla::StatusOr<std::vector<int64>> GetStride(XlaOpKernelContext* ctx) {
if (ctx->num_inputs() == 1) {
return stride_;
}
const TensorShape stride_shape = ctx->InputShape(2); const TensorShape stride_shape = ctx->InputShape(2);
// Validate input sizes. // Validate input sizes.
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape), if (!TensorShapeUtils::IsVector(stride_shape)) {
errors::InvalidArgument("stride must be a vector, not shape ", return errors::InvalidArgument("stride must be a vector, not shape ",
stride_shape.DebugString())); stride_shape.DebugString());
OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(),
errors::InvalidArgument("Sliding window stride field must "
"specify ",
num_dims(), " dimensions"));
stride.clear();
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride));
} }
const TensorShape input_shape = ctx->InputShape(0); if (stride_shape.num_elements() != num_dims()) {
OP_REQUIRES(ctx, input_shape.dims() == num_dims(), return errors::InvalidArgument(
errors::InvalidArgument("Input to ", type_string(), "Sliding window stride field must "
" operator must have ", num_dims(), "specify ",
" dimensions")); num_dims(), " dimensions");
}
xla::XlaBuilder* const b = ctx->builder(); std::vector<int64> stride;
auto input = auto status = ctx->ConstantInputAsIntVector(2, &stride);
XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_); if (!status.ok()) {
auto reduce = xla::ReduceWindow(input, InitValue(b), *Reduction(ctx), ksize, return status;
stride, padding_); }
auto pooled = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); return stride;
ctx->SetOutput(0,
PostProcessOutput(ctx, pooled, input_type(0), input_shape));
} }
protected: protected:
@ -136,24 +131,48 @@ class PoolingOp : public XlaOpKernel {
xla::PrimitiveType xla_reduction_type_; xla::PrimitiveType xla_reduction_type_;
}; };
// Converts the tensor data format to the one required by the XLA pooling
// library.
xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format,
int num_spatial_dims) {
int num_dims = num_spatial_dims + 2;
int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format);
int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format);
gtl::InlinedVector<int64, 4> spatial_dimensions(num_spatial_dims);
for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
spatial_dimensions[spatial_dim] =
GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim);
}
return xla::TensorFormat(/*batch_dimension=*/batch_dimension,
/*feature_dimension=*/feature_dimension,
/*spatial_dimensions=*/spatial_dimensions);
}
class MaxPoolOp : public PoolingOp { class MaxPoolOp : public PoolingOp {
public: public:
MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
: PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
/*reduction_type=*/ctx->input_type(0)) {} /*reduction_type=*/ctx->input_type(0)) {}
xla::XlaOp InitValue(xla::XlaBuilder* b) override { void Compile(XlaOpKernelContext* ctx) override {
return xla::MinValue(b, xla_reduction_type_); auto ksize_or_error = GetKernelSize(ctx);
} OP_REQUIRES_OK(ctx, ksize_or_error.status());
std::vector<int64> ksize = ksize_or_error.ValueOrDie();
const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { auto stride_or_error = GetStride(ctx);
return ctx->GetOrCreateMax(reduction_type_); OP_REQUIRES_OK(ctx, stride_or_error.status());
} std::vector<int64> stride = stride_or_error.ValueOrDie();
xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, const TensorShape input_shape = ctx->InputShape(0);
const xla::XlaOp& output, DataType dtype, OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
const TensorShape& input_shape) override { errors::InvalidArgument("Input to ", type_string(),
return output; " operator must have ", num_dims(),
" dimensions"));
auto pooling =
xla::MaxPool(ctx->Input(0), ksize, stride, padding_,
XlaTensorFormat(data_format_, input_shape.dims() - 2));
ctx->SetOutput(0, pooling);
} }
}; };
@ -180,9 +199,8 @@ class MaxPool3DOp : public MaxPoolOp {
}; };
REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp); REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);
// Common computation shared between AvgPool and AvgPoolGrad. Divide each // Divide each element of an image by the count of elements that contributed to
// element of an image by the count of elements that contributed to that // that element during pooling.
// element during pooling.
static xla::XlaOp AvgPoolDivideByCount( static xla::XlaOp AvgPoolDivideByCount(
XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype, XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype,
const TensorShape& input_shape, xla::Padding padding, const TensorShape& input_shape, xla::Padding padding,
@ -241,20 +259,34 @@ class AvgPoolOp : public PoolingOp {
/*reduction_type=*/ /*reduction_type=*/
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
xla::XlaOp InitValue(xla::XlaBuilder* b) override { void Compile(XlaOpKernelContext* ctx) override {
return xla::Zero(b, xla_reduction_type_); auto ksize_or_error = GetKernelSize(ctx);
} OP_REQUIRES_OK(ctx, ksize_or_error.status());
std::vector<int64> ksize = ksize_or_error.ValueOrDie();
const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { auto stride_or_error = GetStride(ctx);
return ctx->GetOrCreateAdd(reduction_type_); OP_REQUIRES_OK(ctx, stride_or_error.status());
} std::vector<int64> stride = stride_or_error.ValueOrDie();
xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, const TensorShape input_shape = ctx->InputShape(0);
const xla::XlaOp& output, DataType dtype, OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
const TensorShape& input_shape) override { errors::InvalidArgument("Input to ", type_string(),
return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_, " operator must have ", num_dims(),
ksize_, stride_, num_spatial_dims_, " dimensions"));
data_format_);
auto xla_data_format =
XlaTensorFormat(data_format_, input_shape.dims() - 2);
auto spatial_padding = MakeSpatialPadding(
input_shape.dim_sizes(), ksize, stride, padding_, xla_data_format);
// Convert the input to the reduction type.
auto converted_input =
ConvertElementType(ctx->Input(0), xla_reduction_type_);
auto pooling =
xla::AvgPool(converted_input, ksize, stride, spatial_padding,
xla_data_format, padding_ == xla::Padding::kValid);
// Convert the pooling result back to the input type before returning it.
ctx->SetOutput(0, ConvertElementType(pooling, ctx->input_xla_type(0)));
} }
}; };
@ -104,7 +104,7 @@ class RetvalOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp); TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp);
}; };
REGISTER_XLA_OP(Name("_Retval"), RetvalOp); REGISTER_XLA_OP(Name("_Retval").CompilationOnly(), RetvalOp);
} // anonymous namespace } // anonymous namespace
} // namespace tensorflow } // namespace tensorflow
@ -38,11 +38,15 @@ class SoftmaxOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {
const TensorShape logits_shape = ctx->InputShape(0); const TensorShape logits_shape = ctx->InputShape(0);
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape), OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(logits_shape),
errors::InvalidArgument("logits must be 2-dimensional")); errors::InvalidArgument("logits must have >= 1 dimension, got ",
logits_shape.DebugString()));
const int kBatchDim = 0; // Major dimensions are batch dimensions, minor dimension is the class
const int kClassDim = 1; // dimension.
std::vector<int64> batch_dims(logits_shape.dims() - 1);
std::iota(batch_dims.begin(), batch_dims.end(), 0);
const int kClassDim = logits_shape.dims() - 1;
const DataType type = input_type(0); const DataType type = input_type(0);
const xla::PrimitiveType xla_type = ctx->input_xla_type(0); const xla::PrimitiveType xla_type = ctx->input_xla_type(0);
@ -56,7 +60,7 @@ class SoftmaxOp : public XlaOpKernel {
xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim}); xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim});
// Subtract the max in batch b from every element in batch b. Broadcasts // Subtract the max in batch b from every element in batch b. Broadcasts
// along the batch dimension. // along the batch dimension.
auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim}); auto shifted_logits = xla::Sub(logits, logits_max, batch_dims);
auto exp_shifted = xla::Exp(shifted_logits); auto exp_shifted = xla::Exp(shifted_logits);
const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
xla::PrimitiveType xla_accumulation_type; xla::PrimitiveType xla_accumulation_type;
@ -71,9 +75,9 @@ class SoftmaxOp : public XlaOpKernel {
auto softmax = auto softmax =
log_ log_
// softmax = shifted_logits - log(sum(exp(shifted_logits))) // softmax = shifted_logits - log(sum(exp(shifted_logits)))
? xla::Sub(shifted_logits, xla::Log(sum), {kBatchDim}) ? xla::Sub(shifted_logits, xla::Log(sum), batch_dims)
// softmax = exp(shifted_logits) / sum(exp(shifted_logits)) // softmax = exp(shifted_logits) / sum(exp(shifted_logits))
: xla::Div(exp_shifted, sum, {kBatchDim}); : xla::Div(exp_shifted, sum, batch_dims);
ctx->SetOutput(0, softmax); ctx->SetOutput(0, softmax);
} }
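The generalization above comes down to reducing over the last dimension and broadcasting back over every leading dimension; a small standalone illustration of how those dimension lists come out for a rank-3 input (the shape is made up):

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // For logits of shape [2, 3, 5]: reduce over the class dimension (2) and
  // broadcast the per-row max/sum back over the batch dimensions {0, 1}.
  const int rank = 3;
  std::vector<int64_t> batch_dims(rank - 1);
  std::iota(batch_dims.begin(), batch_dims.end(), 0);
  const int class_dim = rank - 1;
  std::cout << "class dim: " << class_dim << ", batch dims:";
  for (int64_t d : batch_dims) std::cout << " " << d;
  std::cout << "\n";  // prints: class dim: 2, batch dims: 0 1
  return 0;
}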
@ -301,6 +301,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
} }
REGISTER_XLA_OP(Name("While").AllowResourceTypes(), XlaWhileOp); REGISTER_XLA_OP(Name("While").AllowResourceTypes(), XlaWhileOp);
REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes(), XlaWhileOp);
REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp); REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp);
} // namespace tensorflow } // namespace tensorflow
@ -32,6 +32,23 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
return Status::OK(); return Status::OK();
} }
Status HostTensorToMutableBorrowingLiteral(
Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor->dtype(),
host_tensor->shape(), &xla_shape));
return HostTensorToMutableBorrowingLiteral(xla_shape, host_tensor, literal);
}
Status HostTensorToMutableBorrowingLiteral(
const xla::Shape& xla_shape, Tensor* host_tensor,
xla::MutableBorrowingLiteral* literal) {
*literal = xla::MutableBorrowingLiteral(
static_cast<const char*>(DMAHelper::base(host_tensor)), xla_shape);
return Status::OK();
}
Status HostTensorsToBorrowingLiteralTuple( Status HostTensorsToBorrowingLiteralTuple(
tensorflow::gtl::ArraySlice<Tensor> host_tensors, tensorflow::gtl::ArraySlice<Tensor> host_tensors,
xla::BorrowingLiteral* literal) { xla::BorrowingLiteral* literal) {
@ -30,6 +30,16 @@ namespace tensorflow {
// 'host_tensor'. // 'host_tensor'.
Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
xla::BorrowingLiteral* literal); xla::BorrowingLiteral* literal);
// Returns a MutableBorrowingLiteral that utilizes the same underlying buffer
// owned by 'host_tensor', but is mutable via the xla::Literal methods.
Status HostTensorToMutableBorrowingLiteral(
Tensor* host_tensor, xla::MutableBorrowingLiteral* literal);
// Similar to the above, except the literal shape is explicitly provided and used
// instead of obtaining it from the 'host_tensor'. The provided literal shape
// 'xla_shape' must be compatible with the shape of 'host_tensor'.
Status HostTensorToMutableBorrowingLiteral(
const xla::Shape& xla_shape, Tensor* host_tensor,
xla::MutableBorrowingLiteral* literal);
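A hypothetical usage sketch of the mutable-borrowing helper declared above (the function name FillViaLiteral is illustrative, the include path for this header is assumed to be tensorflow/compiler/tf2xla/literal_util.h, and a default-constructible xla::MutableBorrowingLiteral out parameter is assumed, as the signature suggests):

#include "tensorflow/compiler/tf2xla/literal_util.h"  // assumed path
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

Status FillViaLiteral(Tensor* host_tensor) {
  // The literal aliases host_tensor's buffer, so writes made through the
  // literal become visible in the tensor without any copy.
  xla::MutableBorrowingLiteral literal;
  TF_RETURN_IF_ERROR(
      HostTensorToMutableBorrowingLiteral(host_tensor, &literal));
  // ... mutate `literal` here ...
  return Status::OK();
}

}  // namespace tensorflow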
// Returns a BorrowingLiteral tuple that utilizes the same underlying buffers // Returns a BorrowingLiteral tuple that utilizes the same underlying buffers
// owned by 'host_tensors'. // owned by 'host_tensors'.
@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include <queue> #include <queue>
#include <random>
#include <set> #include <set>
#include <unordered_map> #include <unordered_map>
@ -297,4 +298,29 @@ void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype,
} }
} }
namespace {
uint32 InitialRandomSeed() {
// Support for plumbing the TF seed through to XLA is being worked on.
// If a user wants deterministic behavior, their best option
// is to start with a known checkpoint. This also handles issues when
// multiple random calls can be invoked in any order by the TF executor.
// Another option is to use stateless random ops. They have much cleaner
// semantics.
// If a user really wants to set a deterministic seed for XLA-based
// devices, this is the place to do it.
std::random_device rd;
// Make the starting value odd.
return rd() | 1;
}
} // namespace
uint32 GetXLARandomSeed() {
// We initialize the counter with an odd number and increment it by two
// every time. This ensures that it will never be zero, even
// after an overflow. When seeded with zero, some XLA backends
// can return all zeros instead of random numbers.
static std::atomic<uint32> counter(InitialRandomSeed());
return counter.fetch_add(2);
}
} // namespace tensorflow } // namespace tensorflow
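The never-zero property follows from parity: an odd starting value plus an even increment stays odd under 32-bit wraparound. A standalone sketch of the same counter pattern (the starting constant is illustrative):

#include <atomic>
#include <cstdint>

// An odd initial value plus increments of two keeps the counter odd modulo
// 2^32, so the returned value can never wrap to zero.
inline uint32_t NextSeedSketch() {
  static std::atomic<uint32_t> counter(0x9E3779B9u);  // any odd constant works
  return counter.fetch_add(2);
}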
@ -56,6 +56,9 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges);
void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype, void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype,
KernelDef* kdef); KernelDef* kdef);
// Returns the next random seed to use for seeding xla rng.
uint32 GetXLARandomSeed();
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include <cassert> #include <cassert>
@ -22,61 +21,42 @@ namespace tensorflow {
XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
AllocMode alloc_mode) AllocMode alloc_mode)
: raw_function_(static_data.raw_function), : raw_function_(static_data.raw_function_),
result_index_(static_data.result_index), result_index_(static_data.result_index_),
args_(new void*[static_data.num_args]), buffer_table_(new void*[static_data.num_buffers_]),
temps_(new void*[static_data.num_temps]), buffer_infos_(static_data.buffer_infos_),
arg_index_to_temp_index_(new int32[static_data.num_args]), arg_index_table_(static_data.arg_index_table_),
num_args_(static_data.num_args), num_args_(static_data.num_args_),
arg_names_(static_data.arg_names), arg_names_(static_data.arg_names_),
result_names_(static_data.result_names), result_names_(static_data.result_names_),
program_shape_(static_data.program_shape), program_shape_(static_data.program_shape_),
hlo_profile_printer_data_(static_data.hlo_profile_printer_data) { hlo_profile_printer_data_(static_data.hlo_profile_printer_data_) {
bool allocate_entry_params =
alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS;
// Allocate arg and temp buffers. // Allocate arg and temp buffers.
if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) { alloc_buffer_table_ = cpu_function_runtime::MallocContiguousBuffers(
alloc_args_ = cpu_function_runtime::MallocContiguousBuffers( static_data.buffer_infos_, static_data.num_buffers_,
static_data.arg_sizes, static_data.num_args, args_, /*allocate_entry_params=*/allocate_entry_params, buffer_table_,
/*annotate_initialized=*/false);
}
alloc_temps_ = cpu_function_runtime::MallocContiguousBuffers(
static_data.temp_sizes, static_data.num_temps, temps_,
/*annotate_initialized=*/true); /*annotate_initialized=*/true);
for (int i = 0; i < static_data.num_temps; i++) {
if (static_data.temp_sizes[i] < -1) {
int32 param_number = -(static_data.temp_sizes[i] + 2);
arg_index_to_temp_index_[param_number] = i;
}
}
// If Hlo profiling is enabled the generated code expects an appropriately // If Hlo profiling is enabled the generated code expects an appropriately
// sized buffer to be passed in as the last argument. If Hlo profiling is // sized buffer to be passed in as the last argument. If Hlo profiling is
// disabled the last function argument is still present in the function // disabled the last function argument is still present in the function
// signature, but it is ignored by the generated code and we pass in null for // signature, but it is ignored by the generated code and we pass in null for
// it. // it.
if (hlo_profiling_enabled()) { if (hlo_profiling_enabled()) {
profile_counters_ = new int64[static_data.profile_counters_size](); profile_counters_ = new int64[static_data.profile_counters_size_]();
} }
} }
bool XlaCompiledCpuFunction::Run() { bool XlaCompiledCpuFunction::Run() {
// Propagate pointers to the argument buffers into the temps array. Code raw_function_(buffer_table_[result_index_], &run_options_, nullptr,
// generated by XLA discovers the incoming argument pointers from the temps buffer_table_, profile_counters_);
// array.
for (int32 i = 0; i < num_args_; i++) {
temps_[arg_index_to_temp_index_[i]] = args_[i];
}
raw_function_(temps_[result_index_], &run_options_, nullptr, temps_,
profile_counters_);
return true; return true;
} }
XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { XlaCompiledCpuFunction::~XlaCompiledCpuFunction() {
cpu_function_runtime::FreeContiguous(alloc_args_); cpu_function_runtime::FreeContiguous(alloc_buffer_table_);
cpu_function_runtime::FreeContiguous(alloc_temps_); delete[] buffer_table_;
delete[] args_;
delete[] temps_;
delete[] arg_index_to_temp_index_;
delete[] profile_counters_; delete[] profile_counters_;
} }
@ -19,6 +19,7 @@ limitations under the License.
#include <cassert> #include <cassert>
#include <string> #include <string>
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
@ -56,46 +57,85 @@ class XlaCompiledCpuFunction {
// StaticData represents the state necessary to run an XLA-compiled // StaticData represents the state necessary to run an XLA-compiled
// function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for
// AOT this is backed by data compiled into the object file. // AOT this is backed by data compiled into the object file.
struct StaticData { //
// The contents of StaticData are XLA-internal implementation details and
// should not be relied on by clients.
//
// TODO(sanjoy): Come up with a cleaner way to express the constraint we want
// here: generated XlaCompiledCpuFunction subclasses should be able to create
// instances of StaticData but only XlaCompiledCpuFunction should be able to
// read from StaticData instances.
class StaticData {
public:
void set_raw_function(RawFunction raw_function) {
raw_function_ = raw_function;
}
void set_buffer_infos(
const cpu_function_runtime::BufferInfo* buffer_infos) {
buffer_infos_ = buffer_infos;
}
void set_num_buffers(size_t num_buffers) { num_buffers_ = num_buffers; }
void set_arg_index_table(const int32* arg_index_table) {
arg_index_table_ = arg_index_table;
}
void set_num_args(int64 num_args) { num_args_ = num_args; }
void set_result_index(size_t result_index) { result_index_ = result_index; }
void set_arg_names(const char** arg_names) { arg_names_ = arg_names; }
void set_result_names(const char** result_names) {
result_names_ = result_names;
}
void set_program_shape(const xla::ProgramShape* program_shape) {
program_shape_ = program_shape;
}
const xla::HloProfilePrinterData* hlo_profile_printer_data() const {
return hlo_profile_printer_data_;
}
void set_hlo_profile_printer_data(
const xla::HloProfilePrinterData* hlo_profile_printer_data) {
hlo_profile_printer_data_ = hlo_profile_printer_data;
}
void set_profile_counters_size(int64 profile_counters_size) {
profile_counters_size_ = profile_counters_size;
}
private:
// The raw function to call. // The raw function to call.
RawFunction raw_function; RawFunction raw_function_;
// Cardinality and size of arg buffers. // Contains information about the buffers used by the XLA computation.
const intptr_t* arg_sizes = nullptr; const cpu_function_runtime::BufferInfo* buffer_infos_ = nullptr;
size_t num_args = 0; size_t num_buffers_ = 0;
// Cardinality and size of temp buffers. // Entry parameter i is described by
// // buffer_infos[arg_index_table[i]].
// If temp_sizes[i] >= 0 then the i'th temp is a regular temporary buffer. const int32* arg_index_table_ = nullptr;
//
// If temp_sizes[i] == -1 then the i'th temp is a constant buffer. The // There are num_args entry parameters.
// corresponding entry in the temp buffer array needs to be set to null. int64 num_args_ = 0;
//
// If temp_sizes[i] < -1 then the i'th temp is the entry parameter
// -(temp_sizes[i] + 2).
const intptr_t* temp_sizes = nullptr;
size_t num_temps = 0;
// The 0-based index of the result tuple, in the temp buffers. // The 0-based index of the result tuple, in the temp buffers.
size_t result_index = 0; size_t result_index_ = 0;
// [Optional] Arrays of arg and result names. These are arrays of C-style // [Optional] Arrays of arg and result names. These are arrays of C-style
// strings, where the array is terminated by nullptr. // strings, where the array is terminated by nullptr.
const char** arg_names = nullptr; const char** arg_names_ = nullptr;
const char** result_names = nullptr; const char** result_names_ = nullptr;
// [Optional] Arg and result shapes. // [Optional] Arg and result shapes.
const xla::ProgramShape* program_shape = nullptr; const xla::ProgramShape* program_shape_ = nullptr;
// [Optional] Profile printer data. Null if profiling is disabled. // [Optional] Profile printer data. Null if profiling is disabled.
const xla::HloProfilePrinterData* hlo_profile_printer_data = nullptr; const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
// [Optional] The number of profile counters expected in the profile counter // [Optional] The number of profile counters expected in the profile counter
// buffer by the generated code and hlo_profile_printer. 0 if profiling is // buffer by the generated code and hlo_profile_printer. 0 if profiling is
// disabled. This information is already present in // disabled. This information is already present in
// hlo_profile_printer_data but xla::HloProfilePrinterData is forward // hlo_profile_printer_data but xla::HloProfilePrinterData is forward
// declared so we don't have access to that information here. // declared so we don't have access to that information here.
int64 profile_counters_size = 0; int64 profile_counters_size_ = 0;
// Only XlaCompiledCpuFunction is allowed to read the above fields.
friend class XlaCompiledCpuFunction;
}; };
// AllocMode controls the buffer allocation mode. // AllocMode controls the buffer allocation mode.
@ -135,14 +175,25 @@ class XlaCompiledCpuFunction {
// ------------------------------ // ------------------------------
// Arg methods for managing input buffers. Buffers are in row-major order. // Arg methods for managing input buffers. Buffers are in row-major order.
// Returns the underlying array of argument buffers, where args()[I] is the
// buffer for the positional argument at index I.
void** args() { return args_; }
const void* const* args() const { return args_; }
// Returns the buffer for the positional argument at the given `index`. // Returns the buffer for the positional argument at the given `index`.
void* arg_data(size_t index) { return args_[index]; } void* arg_data(size_t index) {
const void* arg_data(size_t index) const { return args_[index]; } return buffer_table_[arg_index_table_[index]];
}
const void* arg_data(size_t index) const {
return buffer_table_[arg_index_table_[index]];
}
int num_args() const { return num_args_; }
// Returns the size of entry parameter `idx`.
//
// There is a static version of this method on tfcompile generated subclasses
// of XlaCompiledCpuFunction, but try to prefer this when possible since it
// works both for XlaJitCompiledCpuFunction and AOT compiled subclasses.
int arg_size(int idx) const {
assert(idx < num_args());
return buffer_infos_[arg_index_table_[idx]].size();
}
// Sets the buffer for the positional argument at the given `index` to `data`. // Sets the buffer for the positional argument at the given `index` to `data`.
// Must be called before Run to have an effect. May be called under any // Must be called before Run to have an effect. May be called under any
@ -155,7 +206,9 @@ class XlaCompiledCpuFunction {
// //
// Aliasing of argument and result buffers is not allowed, and results in // Aliasing of argument and result buffers is not allowed, and results in
// undefined behavior. // undefined behavior.
void set_arg_data(size_t index, void* data) { args_[index] = data; } void set_arg_data(size_t index, void* data) {
buffer_table_[arg_index_table_[index]] = data;
}
// ------------------------------ // ------------------------------
// Result methods for managing output buffers. Buffers are in row-major order. // Result methods for managing output buffers. Buffers are in row-major order.
@ -165,9 +218,9 @@ class XlaCompiledCpuFunction {
// Returns the underlying array of result buffers, where results()[I] is the // Returns the underlying array of result buffers, where results()[I] is the
// buffer for the positional result at index I. // buffer for the positional result at index I.
void** results() { return static_cast<void**>(temps_[result_index_]); } void** results() { return static_cast<void**>(buffer_table_[result_index_]); }
const void* const* results() const { const void* const* results() const {
return static_cast<const void* const*>(temps_[result_index_]); return static_cast<const void* const*>(buffer_table_[result_index_]);
} }
// Profile counters for this XLA computation. // Profile counters for this XLA computation.
@ -225,25 +278,28 @@ class XlaCompiledCpuFunction {
const RawFunction raw_function_; const RawFunction raw_function_;
const size_t result_index_; const size_t result_index_;
// Arrays of argument and temp buffers; entries in args_ may be overwritten by // Array containing pointers to argument and temp buffers (slots corresponding
// the user. // to constant and on-stack buffers are null).
void** args_ = nullptr; void** const buffer_table_;
void** temps_ = nullptr;
// Argument i needs to be placed in temps_[arg_index_to_temp_index_[i]] for // Describes the buffers used by the XLA computation.
// XLA generated code to be able to find it. const cpu_function_runtime::BufferInfo* const buffer_infos_;
// Argument i needs to be placed in buffer_table_[arg_index_to_temp_index_[i]]
// for XLA generated code to be able to find it.
// //
// For now we need to keep around the args_ array because there is code that // For now we need to keep around the args_ array because there is code that
// depends on args() returning a void**. However, in the future we may remove // depends on args() returning a void**. However, in the future we may remove
// args_ in favor of using temps_ as the sole storage for the arguments. // args_ in favor of using buffer_table_ as the sole storage for the
int32* arg_index_to_temp_index_; // arguments.
const int32* const arg_index_table_;
// The number of incoming arguments. // The number of incoming arguments.
int32 num_args_; const int32 num_args_;
// Backing memory for individual arg and temp buffers. // Backing memory for buffer_table_ and args_, the latter depending on
void* alloc_args_ = nullptr; // AllocMode.
void* alloc_temps_ = nullptr; void* alloc_buffer_table_ = nullptr;
// Backing memory for profiling counters. // Backing memory for profiling counters.
int64* profile_counters_ = nullptr; int64* profile_counters_ = nullptr;
@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
@ -35,45 +36,6 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace { namespace {
// Returns a vector of positional argument buffer sizes.
xla::StatusOr<std::vector<intptr_t>> ComputeArgSizes(
const xla::ProgramShape& program_shape) {
std::vector<intptr_t> arg_sizes;
const size_t num_args = program_shape.parameters_size();
arg_sizes.reserve(num_args);
for (int i = 0; i < num_args; ++i) {
const xla::Shape& arg_shape = program_shape.parameters(i);
constexpr size_t kPointerSize = sizeof(void*);
arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize));
}
return std::move(arg_sizes);
}
// Returns a vector of positional temporary buffer sizes.
xla::StatusOr<std::vector<intptr_t>> ComputeTempSizes(
const xla::BufferAssignment& buffer_assignment) {
const std::vector<xla::BufferAllocation>& allocations =
buffer_assignment.Allocations();
std::vector<intptr_t> temp_sizes;
temp_sizes.reserve(allocations.size());
for (const xla::BufferAllocation& allocation : allocations) {
if (allocation.is_constant() || allocation.is_thread_local()) {
// Constants are lowered to globals. Thread locals are lowered to
// allocas.
temp_sizes.push_back(-1);
} else if (allocation.is_entry_computation_parameter()) {
// Entry computation parameters need some preprocessing in
// XlaCompiledCpuFunction::Run. See the comment on
// XlaCompiledCpuFunction::StaticData::temp_sizes.
temp_sizes.push_back(-allocation.parameter_number() - 2);
} else {
temp_sizes.push_back(allocation.size());
}
}
return std::move(temp_sizes);
}
// Returns the index of the result in the temp buffers. // Returns the index of the result in the temp buffers.
xla::StatusOr<size_t> ComputeResultIndex( xla::StatusOr<size_t> ComputeResultIndex(
const xla::BufferAssignment& buffer_assignment) { const xla::BufferAssignment& buffer_assignment) {
@ -157,11 +119,11 @@ XlaJitCompiledCpuFunction::Compile(
const xla::BufferAssignment& buffer_assignment = const xla::BufferAssignment& buffer_assignment =
cpu_executable->buffer_assignment(); cpu_executable->buffer_assignment();
// Compute buffer sizes and the result index, needed to run the raw function. // Compute buffer infos and the result index, needed to run the raw function.
TF_ASSIGN_OR_RETURN(std::vector<intptr_t> arg_sizes, std::vector<cpu_function_runtime::BufferInfo> buffer_infos =
ComputeArgSizes(*program_shape)); xla::cpu::CreateBufferInfosFromBufferAssignment(buffer_assignment);
TF_ASSIGN_OR_RETURN(std::vector<intptr_t> temp_sizes, std::vector<int32> arg_index_table =
ComputeTempSizes(buffer_assignment)); xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos);
TF_ASSIGN_OR_RETURN(size_t result_index, TF_ASSIGN_OR_RETURN(size_t result_index,
ComputeResultIndex(buffer_assignment)); ComputeResultIndex(buffer_assignment));
@ -169,28 +131,28 @@ XlaJitCompiledCpuFunction::Compile(
new XlaJitCompiledCpuFunction); new XlaJitCompiledCpuFunction);
XlaJitCompiledCpuFunction* jit = jit_unique_ptr.get(); XlaJitCompiledCpuFunction* jit = jit_unique_ptr.get();
jit->executable_ = std::move(executable); jit->executable_ = std::move(executable);
jit->arg_sizes_ = std::move(arg_sizes); jit->buffer_infos_ = std::move(buffer_infos);
jit->temp_sizes_ = std::move(temp_sizes); jit->arg_index_table_ = std::move(arg_index_table);
jit->program_shape_ = std::move(program_shape); jit->program_shape_ = std::move(program_shape);
jit->static_data_.raw_function = std::move(raw_function); jit->static_data_.set_raw_function(raw_function);
jit->static_data_.arg_sizes = jit->arg_sizes_.data(); jit->static_data_.set_buffer_infos(jit->buffer_infos_.data());
jit->static_data_.num_args = jit->arg_sizes_.size(); jit->static_data_.set_num_buffers(jit->buffer_infos_.size());
jit->static_data_.temp_sizes = jit->temp_sizes_.data(); jit->static_data_.set_arg_index_table(jit->arg_index_table_.data());
jit->static_data_.num_temps = jit->temp_sizes_.size(); jit->static_data_.set_num_args(jit->arg_index_table_.size());
jit->static_data_.result_index = result_index; jit->static_data_.set_result_index(result_index);
// Optional metadata is collected and set below. // Optional metadata is collected and set below.
CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_); CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_);
CollectNames(config.fetch(), &jit->nonempty_result_names_, CollectNames(config.fetch(), &jit->nonempty_result_names_,
&jit->result_names_); &jit->result_names_);
jit->static_data_.arg_names = jit->arg_names_.data(); jit->static_data_.set_arg_names(jit->arg_names_.data());
jit->static_data_.result_names = jit->result_names_.data(); jit->static_data_.set_result_names(jit->result_names_.data());
jit->static_data_.program_shape = jit->program_shape_.get(); jit->static_data_.set_program_shape(jit->program_shape_.get());
if (cpu_executable->hlo_profiling_enabled()) { if (cpu_executable->hlo_profiling_enabled()) {
jit->static_data_.hlo_profile_printer_data = jit->static_data_.set_hlo_profile_printer_data(
&cpu_executable->hlo_profile_printer_data(); &cpu_executable->hlo_profile_printer_data());
jit->static_data_.profile_counters_size = jit->static_data_.set_profile_counters_size(
cpu_executable->hlo_profile_printer_data().profile_counters_size(); cpu_executable->hlo_profile_printer_data().profile_counters_size());
} }
return std::move(jit_unique_ptr); return std::move(jit_unique_ptr);
@ -66,9 +66,11 @@ class XlaJitCompiledCpuFunction {
// The static data is backed by the rest of the state in this class. // The static data is backed by the rest of the state in this class.
XlaCompiledCpuFunction::StaticData static_data_; XlaCompiledCpuFunction::StaticData static_data_;
// The backing arrays of arg and temp buffer sizes. // The backing array for buffer infos.
std::vector<intptr_t> arg_sizes_; std::vector<cpu_function_runtime::BufferInfo> buffer_infos_;
std::vector<intptr_t> temp_sizes_;
// The backing array for the arg index table.
std::vector<int32> arg_index_table_;
// The backing arrays of arg and result names. We hold the actual strings in // The backing arrays of arg and result names. We hold the actual strings in
// nonempty_*_names_, and hold arrays of pointers in *_names_ for the static // nonempty_*_names_, and hold arrays of pointers in *_names_ for the static
@ -409,7 +409,7 @@ class Array {
// Returns the total number of elements in the array. // Returns the total number of elements in the array.
int64 num_elements() const { int64 num_elements() const {
return std::accumulate(sizes_.begin(), sizes_.end(), 1, return std::accumulate(sizes_.begin(), sizes_.end(), 1LL,
std::multiplies<int64>()); std::multiplies<int64>());
} }
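The change from 1 to 1LL matters because std::accumulate deduces its accumulator type from the initial value, so an int seed silently narrows the running product back to 32 bits on every step. A standalone illustration (the sizes are made up):

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> sizes = {100000, 100000};  // true product: 10^10
  // Accumulator type is int: the product is narrowed to 32 bits.
  long long narrowed = std::accumulate(sizes.begin(), sizes.end(), 1,
                                       std::multiplies<int64_t>());
  // Accumulator type is long long: the full 64-bit product survives.
  long long full = std::accumulate(sizes.begin(), sizes.end(), 1LL,
                                   std::multiplies<int64_t>());
  std::cout << narrowed << " != " << full << "\n";
  return 0;
}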
@ -121,6 +121,30 @@ xla_test(
], ],
) )
cc_library(
name = "pooling",
srcs = ["pooling.cc"],
hdrs = ["pooling.h"],
deps = [
":arithmetic",
":constants",
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/core:lib",
],
)
xla_test(
name = "pooling_test",
srcs = ["pooling_test.cc"],
deps = [
":pooling",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
cc_library( cc_library(
name = "prng", name = "prng",
srcs = ["prng.cc"], srcs = ["prng.cc"],
@ -144,7 +168,7 @@ cc_library(
":numeric", ":numeric",
"//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client:xla_builder",
], ],
) )
@ -161,7 +185,7 @@ xla_test(
"//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/compiler/xla/tests:xla_internal_test_main",
], ],

View File

@ -0,0 +1,183 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/pooling.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
namespace xla {
namespace {
// Common computation shared between AvgPool and AvgPoolGrad. Divide each
// element of an image by the count of elements that contributed to that
// element during pooling.
XlaOp AvgPoolDivideByCountWithGeneralPadding(
XlaOp sums, PrimitiveType dtype,
tensorflow::gtl::ArraySlice<int64> input_shape,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> spatial_padding,
tensorflow::gtl::ArraySlice<int64> ksize,
tensorflow::gtl::ArraySlice<int64> stride,
const TensorFormat& data_format) {
// The padding shouldn't be included in the counts. We use another
// ReduceWindow to find the right counts.
const int num_spatial_dims = spatial_padding.size();
std::vector<int64> input_dim_sizes(num_spatial_dims);
std::vector<int64> window_dims(num_spatial_dims);
std::vector<int64> window_ksize(num_spatial_dims);
std::vector<int64> window_stride(num_spatial_dims);
CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
<< "Invalid number of spatial dimentions in data format specification";
for (int i = 0; i < num_spatial_dims; ++i) {
int dim = data_format.spatial_dimension(i);
input_dim_sizes[i] = input_shape[dim];
window_dims[i] = dim;
window_ksize[i] = ksize[dim];
window_stride[i] = stride[dim];
}
XlaBuilder* b = sums.builder();
// Build a matrix of all 1s, with the same width/height as the input.
auto ones = Broadcast(One(b, dtype), input_dim_sizes);
PaddingConfig padding_config;
for (int i = 0; i < num_spatial_dims; ++i) {
auto dims = padding_config.add_dimensions();
dims->set_edge_padding_low(spatial_padding[i].first);
dims->set_edge_padding_high(spatial_padding[i].second);
}
auto zero = Zero(b, dtype);
auto padded_ones = Pad(ones, zero, padding_config);
// Perform a ReduceWindow with the same window size, strides, and padding
// to count the number of contributions to each result element.
auto counts =
ReduceWindow(padded_ones, zero, CreateScalarAddComputation(dtype, b),
window_ksize, window_stride, Padding::kValid);
return Div(sums, counts, window_dims);
}
// Sums all elements in the window specified by 'kernel_size' and 'stride'.
XlaOp ComputeSums(XlaOp operand, XlaOp init_value,
tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride,
const TensorFormat& data_format) {
XlaBuilder* b = operand.builder();
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
TF_ASSIGN_OR_RETURN(Shape init_shape, b->GetShape(init_value));
PrimitiveType accumulation_type = init_shape.element_type();
auto add_computation = CreateScalarAddComputation(accumulation_type, b);
return ReduceWindow(operand, init_value, add_computation, kernel_size,
stride, Padding::kValid);
});
}
// Creates a padding configuration out of spatial padding values.
PaddingConfig MakeSpatialPaddingConfig(
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> spatial_padding,
tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride,
const TensorFormat& data_format) {
const int num_spatial_dims = kernel_size.size() - 2;
PaddingConfig padding_config;
for (int i = 0; i < 2 + num_spatial_dims; ++i) {
padding_config.add_dimensions();
}
CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
<< "Invalid number of spatial dimentions in data format specification";
for (int i = 0; i < num_spatial_dims; ++i) {
int dim = data_format.spatial_dimension(i);
auto padding_dimension = padding_config.mutable_dimensions(dim);
padding_dimension->set_edge_padding_low(spatial_padding[i].first);
padding_dimension->set_edge_padding_high(spatial_padding[i].second);
}
return padding_config;
}
} // namespace
XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
const TensorFormat& data_format) {
XlaBuilder* b = operand.builder();
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
PrimitiveType dtype = operand_shape.element_type();
auto max_computation = CreateScalarMaxComputation(dtype, b);
auto init_value = MinValue(b, dtype);
return ReduceWindow(operand, init_value, max_computation, kernel_size,
stride, padding);
});
}
XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const TensorFormat& data_format,
const bool counts_include_padding) {
XlaBuilder* b = operand.builder();
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
PrimitiveType dtype = operand_shape.element_type();
auto init_value = Zero(b, dtype);
std::vector<int64> input_size(operand_shape.dimensions().begin(),
operand_shape.dimensions().end());
auto padding_config =
MakeSpatialPaddingConfig(padding, kernel_size, stride, data_format);
auto padded_operand = Pad(operand, Zero(b, dtype), padding_config);
auto pooled = ComputeSums(padded_operand, init_value, kernel_size, stride,
data_format);
if (counts_include_padding) {
// If counts include padding, all windows have the same number of elements
// contributing to each average. Divide by the window size everywhere to
// get the average.
int64 window_size =
std::accumulate(kernel_size.begin(), kernel_size.end(), 1,
[](int64 x, int64 y) { return x * y; });
auto divisor = ConstantR0WithType(b, dtype, window_size);
return pooled / divisor;
} else {
return AvgPoolDivideByCountWithGeneralPadding(
pooled, dtype, input_size, padding, kernel_size, stride, data_format);
}
});
}
std::vector<std::pair<int64, int64>> MakeSpatialPadding(
tensorflow::gtl::ArraySlice<int64> input_size,
tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
const TensorFormat& data_format) {
const int num_spatial_dims = kernel_size.size() - 2;
std::vector<int64> input_spatial_dimensions;
std::vector<int64> kernel_size_spatial_dimensions;
std::vector<int64> stride_spatial_dimensions;
CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
<< "Invalid number of spatial dimentions in data format specification";
for (int i = 0; i < num_spatial_dims; ++i) {
int dim = data_format.spatial_dimension(i);
input_spatial_dimensions.push_back(input_size[dim]);
kernel_size_spatial_dimensions.push_back(kernel_size[dim]);
stride_spatial_dimensions.push_back(stride[dim]);
}
return MakePadding(input_spatial_dimensions, kernel_size_spatial_dimensions,
stride_spatial_dimensions, padding);
}
} // namespace xla
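A worked example of the counts-not-including-padding path, using the data from the AvgPool2DWithGeneralPaddingCountNotIncludePadding test later in this diff: for the 2x5 spatial input {{1,2,3,4,5},{5,4,3,2,1}} with padding {1,1} and {2,1}, kernel 3x3, and stride 3x3, the padded input is 4x8 and produces a 1x2 output. The first window sums 1 + 5 = 6 over only 2 unpadded elements and the second sums 2+3+4+4+3+2 = 18 over 6, so dividing by the ReduceWindow-computed counts gives {3, 3}; dividing by the full window size of 9 (the counts_include_padding=true path) would instead give 6/9 and 2.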

View File

@ -0,0 +1,73 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
namespace xla {
// Tensor format for reduce window operations.
class TensorFormat {
public:
TensorFormat(int batch_dimension, int feature_dimension,
tensorflow::gtl::ArraySlice<int64> spatial_dimensions)
: batch_dimension_(batch_dimension),
feature_dimension_(feature_dimension),
spatial_dimensions_(spatial_dimensions.begin(),
spatial_dimensions.end()) {}
int batch_dimension() const { return batch_dimension_; }
int feature_dimension() const { return feature_dimension_; }
int spatial_dimension(int dim) const { return spatial_dimensions_[dim]; }
int num_spatial_dims() const { return spatial_dimensions_.size(); }
private:
// The number of the dimension that represents the batch.
int batch_dimension_;
// The number of the dimension that represents the features.
int feature_dimension_;
// The dimension numbers for the spatial dimensions.
tensorflow::gtl::InlinedVector<int, 4> spatial_dimensions_;
};
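For example, the NCHW format used by the tests later in this diff is TensorFormat(/*batch_dimension=*/0, /*feature_dimension=*/1, /*spatial_dimensions=*/{2, 3}); an NHWC 4-D tensor would use TensorFormat(/*batch_dimension=*/0, /*feature_dimension=*/3, /*spatial_dimensions=*/{1, 2}).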
// Computes the max pool of 'operand'.
XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
const TensorFormat& data_format);
// Computes the average pool of 'operand'.
XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const TensorFormat& data_format,
const bool counts_include_padding);
// Returns the list of low and high padding elements in each spatial dimension
// for the given 'padding' specification.
std::vector<std::pair<int64, int64>> MakeSpatialPadding(
tensorflow::gtl::ArraySlice<int64> input_size,
tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
const TensorFormat& data_format);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_
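A minimal usage sketch of this header (hypothetical caller code, not part of the change; it assumes the client's free functions such as Parameter from xla_builder.h and mirrors the conventions of pooling_test.cc below):

    XlaBuilder b("pooling_example");
    // NCHW layout: batch=0, feature=1, spatial dimensions 2 and 3.
    TensorFormat nchw(/*batch_dimension=*/0, /*feature_dimension=*/1,
                      /*spatial_dimensions=*/{2, 3});
    XlaOp in = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 1, 2, 5}), "in");
    // Kernel and stride cover all four dimensions, with singleton
    // batch/feature entries.
    std::vector<int64> kernel = {1, 1, 2, 2};
    std::vector<int64> stride = {1, 1, 2, 2};
    MaxPool(in, kernel, stride, Padding::kValid, nchw);
    // AvgPool takes explicit low/high padding per spatial dimension.
    auto spatial_padding = MakeSpatialPadding(/*input_size=*/{1, 1, 2, 5},
                                              kernel, stride, Padding::kSame,
                                              nchw);
    AvgPool(in, kernel, stride, spatial_padding, nchw,
            /*counts_include_padding=*/false);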

View File

@ -0,0 +1,185 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/pooling.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
namespace xla {
namespace {
TensorFormat MakeNCHWFormat(int num_spatial_dims) {
tensorflow::gtl::InlinedVector<int64, 4> spatial_dimensions;
for (int i = 0; i < num_spatial_dims; ++i) {
spatial_dimensions.push_back(i + 2);
}
return TensorFormat(/*batch_dimension=*/0, /*feature_dimension=*/1,
/*spatial_dimensions=*/spatial_dimensions);
}
std::vector<std::pair<int64, int64>> MakeGeneralPadding(
XlaOp input, tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
const xla::TensorFormat& data_format) {
XlaBuilder* b = input.builder();
Shape operand_shape = b->GetShape(input).ValueOrDie();
std::vector<int64> input_size(operand_shape.dimensions().begin(),
operand_shape.dimensions().end());
return MakeSpatialPadding(input_size, kernel_size, stride, padding,
data_format);
}
// Add singleton batch and feature dimensions to spatial dimensions, according
// to 'data_format' specification.
std::vector<int64> ExpandWithBatchAndFeatureDimensions(
tensorflow::gtl::ArraySlice<int64> spatial_dim_sizes,
const xla::TensorFormat& data_format) {
const int num_spatial_dims = spatial_dim_sizes.size();
std::vector<int64> tensor_sizes(num_spatial_dims + 2, 1);
for (int i = 0; i < num_spatial_dims; ++i) {
int dim = data_format.spatial_dimension(i);
tensor_sizes[dim] = spatial_dim_sizes[i];
}
return tensor_sizes;
}
class PoolingTest : public ClientLibraryTestBase {
public:
ErrorSpec error_spec_{0.0001};
};
XLA_TEST_F(PoolingTest, MaxPool2D) {
XlaBuilder builder(TestName());
XlaOp input = ConstantR4FromArray4D<float>(
&builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
auto data_format = MakeNCHWFormat(2);
auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
auto stride = kernel_size;
MaxPool(input, kernel_size, stride, Padding::kValid, data_format);
ComputeAndCompareR4<float>(&builder, {{{{5, 4}}}}, {}, error_spec_);
}
XLA_TEST_F(PoolingTest, MaxPool2DWithPadding) {
XlaBuilder builder(TestName());
XlaOp input = ConstantR4FromArray4D<float>(
&builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
auto data_format = MakeNCHWFormat(2);
auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
auto stride = kernel_size;
MaxPool(input, kernel_size, stride, Padding::kSame, data_format);
ComputeAndCompareR4<float>(&builder, {{{{5, 4, 5}}}}, {}, error_spec_);
}
XLA_TEST_F(PoolingTest, MaxPool2DWithPaddingAndStride) {
XlaBuilder builder(TestName());
XlaOp input = ConstantR4FromArray4D<float>(
&builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
auto data_format = MakeNCHWFormat(2);
auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format);
MaxPool(input, kernel_size, stride, Padding::kSame, data_format);
ComputeAndCompareR4<float>(&builder, {{{{5, 4, 4, 5, 5}, {5, 4, 3, 2, 1}}}},
{}, error_spec_);
}
XLA_TEST_F(PoolingTest, AvgPool2D) {
XlaBuilder builder(TestName());
XlaOp input = ConstantR4FromArray4D<float>(
&builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
auto data_format = MakeNCHWFormat(2);
auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
auto stride = kernel_size;
auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kValid,
data_format);
AvgPool(input, kernel_size, stride, padding, data_format,
/*counts_include_padding=*/true);
ComputeAndCompareR4<float>(&builder, {{{{3, 3}}}}, {}, error_spec_);
}
XLA_TEST_F(PoolingTest, AvgPool2DWithPadding) {
XlaBuilder builder(TestName());
XlaOp input = ConstantR4FromArray4D<float>(
&builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
auto data_format = MakeNCHWFormat(2);
auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
auto stride = kernel_size;
auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kSame,
data_format);
AvgPool(input, kernel_size, stride, padding, data_format,
/*counts_include_padding=*/false);
ComputeAndCompareR4<float>(&builder, {{{{3, 3, 3}}}}, {}, error_spec_);
}
XLA_TEST_F(PoolingTest, AvgPool2DWithPaddingAndStride) {
XlaBuilder builder(TestName());
XlaOp input = ConstantR4FromArray4D<float>(
&builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
auto data_format = MakeNCHWFormat(2);
auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format);
auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kSame,
data_format);
AvgPool(input, kernel_size, stride, padding, data_format,
/*counts_include_padding=*/false);
ComputeAndCompareR4<float>(&builder,
{{{{3, 3, 3, 3, 3}, {4.5, 3.5, 2.5, 1.5, 1}}}}, {},
error_spec_);
}
XLA_TEST_F(PoolingTest, AvgPool2DWithGeneralPaddingCountNotIncludePadding) {
XlaBuilder builder(TestName());
XlaOp input = ConstantR4FromArray4D<float>(
&builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
auto data_format = MakeNCHWFormat(2);
auto kernel_size = ExpandWithBatchAndFeatureDimensions({3, 3}, data_format);
auto stride = kernel_size;
AvgPool(input, kernel_size, stride, {{1, 1}, {2, 1}}, data_format,
/*counts_include_padding=*/false);
ComputeAndCompareR4<float>(&builder, {{{{3, 3}}}}, {}, error_spec_);
}
XLA_TEST_F(PoolingTest,
AvgPool2DWithGeneralPaddingCountNotIncludePaddingAndStride) {
XlaBuilder builder(TestName());
XlaOp input = ConstantR4FromArray4D<float>(
&builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
auto data_format = MakeNCHWFormat(2);
auto kernel_size = ExpandWithBatchAndFeatureDimensions({3, 3}, data_format);
auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
AvgPool(input, kernel_size, stride, {{2, 1}, {1, 1}}, data_format,
/*counts_include_padding=*/false);
ComputeAndCompareR4<float>(&builder, {{{{1.5, 3, 4.5}, {3, 3, 3}}}}, {},
error_spec_);
}
} // namespace
} // namespace xla

View File

@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_ #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"

View File

@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_macros.h"

View File

@ -303,7 +303,7 @@ StatusOr<std::unique_ptr<Literal>> LocalClient::TransferFromOutfeedLocal(
const Shape& shape, int device_ordinal) { const Shape& shape, int device_ordinal) {
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
backend().stream_executor(device_ordinal)); backend().stream_executor(device_ordinal));
auto literal = MakeUnique<Literal>(); auto literal = Literal::CreateFromShape(shape);
TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed( TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
executor, shape, literal.get())); executor, shape, literal.get()));
return std::move(literal); return std::move(literal);

View File

@ -45,21 +45,6 @@ int64 GetUniqueId() {
return id; return id;
} }
// Returns true if an instruction with the given opcode can be the root of the
// computation.
bool CanBeRoot(HloOpcode opcode) {
switch (opcode) {
case HloOpcode::kAfterAll:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kOutfeed:
case HloOpcode::kTrace:
return false;
default:
return true;
}
}
} // namespace } // namespace
XlaOp operator-(const XlaOp& x) { return Neg(x); } XlaOp operator-(const XlaOp& x) { return Neg(x); }
@ -142,28 +127,13 @@ XlaOp XlaBuilder::ReportErrorOrReturn(
return ReportErrorOrReturn(op_creator()); return ReportErrorOrReturn(op_creator());
} }
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) const { StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
TF_RETURN_IF_ERROR(first_error_); TF_RETURN_IF_ERROR(first_error_);
TF_RET_CHECK((root_id >= 0) && (root_id < instructions_.size()));
TF_RET_CHECK(root_id != nullptr);
ProgramShape program_shape; ProgramShape program_shape;
// Not all instructions can be roots. Walk backwards from the last added *program_shape.mutable_result() = instructions_[root_id].shape();
// instruction until a valid root is found.
int64 index = instructions_.size() - 1;
for (; index >= 0; index--) {
TF_ASSIGN_OR_RETURN(HloOpcode opcode,
StringToHloOpcode(instructions_[index].opcode()));
if (CanBeRoot(opcode)) {
break;
}
}
if (index < 0) {
return FailedPrecondition("no root instruction was found");
}
*root_id = instructions_[index].id();
*program_shape.mutable_result() = instructions_[index].shape();
// Check that the parameter numbers are continuous from 0, and add parameter // Check that the parameter numbers are continuous from 0, and add parameter
// shapes and names to the program shape. // shapes and names to the program shape.
@ -188,8 +158,15 @@ StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) const {
} }
StatusOr<ProgramShape> XlaBuilder::GetProgramShape() const { StatusOr<ProgramShape> XlaBuilder::GetProgramShape() const {
int64 root; TF_RET_CHECK(!instructions_.empty());
return GetProgramShape(&root); return GetProgramShape(instructions_.back().id());
}
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(XlaOp root) const {
if (root.builder_ != this) {
return InvalidArgument("Given root operation is not in this computation.");
}
return GetProgramShape(root.handle());
} }
void XlaBuilder::IsConstantVisitor(const int64 op_handle, void XlaBuilder::IsConstantVisitor(const int64 op_handle,
@ -257,17 +234,29 @@ StatusOr<XlaComputation> XlaBuilder::Build() {
first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
return AppendStatus(first_error_, backtrace); return AppendStatus(first_error_, backtrace);
} }
return Build(instructions_.back().id());
}
StatusOr<XlaComputation> XlaBuilder::Build(XlaOp root) {
if (root.builder_ != this) {
return InvalidArgument("Given root operation is not in this computation.");
}
return Build(root.handle());
}
StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {
if (!first_error_.ok()) {
string backtrace;
first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
return AppendStatus(first_error_, backtrace);
}
HloComputationProto entry; HloComputationProto entry;
entry.set_id(GetUniqueId()); // Give the computation a global unique id. entry.set_id(GetUniqueId()); // Give the computation a global unique id.
entry.set_name(StrCat(name_, entry.id())); // Ensure that the name is unique. entry.set_name(StrCat(name_, entry.id())); // Ensure that the name is unique.
{ TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), GetProgramShape(root_id));
int64 root_id;
TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(),
GetProgramShape(&root_id));
entry.set_root_id(root_id); entry.set_root_id(root_id);
}
for (auto& instruction : instructions_) { for (auto& instruction : instructions_) {
// Ensures that the instruction names are unique among the whole graph. // Ensures that the instruction names are unique among the whole graph.
@ -1099,11 +1088,11 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
sharding_builder::AssignDevice(0); sharding_builder::AssignDevice(0);
XlaScopedShardingAssignment scoped_sharding(this, XlaScopedShardingAssignment scoped_sharding(this,
infeed_instruction_sharding); infeed_instruction_sharding);
TF_ASSIGN_OR_RETURN(infeed, TF_ASSIGN_OR_RETURN(
AddInstruction(std::move(instr), HloOpcode::kInfeed)); infeed, AddInstruction(std::move(instr), HloOpcode::kInfeed, {}));
} else { } else {
TF_ASSIGN_OR_RETURN(infeed, TF_ASSIGN_OR_RETURN(
AddInstruction(std::move(instr), HloOpcode::kInfeed)); infeed, AddInstruction(std::move(instr), HloOpcode::kInfeed, {}));
} }
// The infeed instruction produces a tuple of the infeed data and a token // The infeed instruction produces a tuple of the infeed data and a token
@ -1892,6 +1881,61 @@ XlaOp XlaBuilder::CrossReplicaSum(
}); });
} }
XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
// The HloInstruction for Alltoall currently only handles the data
// communication: it accepts N already split parts and scatters them to N
// cores, and each core gathers the N received parts into a tuple as the
// output. So here we explicitly split the operand before the hlo alltoall,
// and concat the tuple elements.
//
// First, run shape inference to make sure the shapes are valid.
TF_RETURN_IF_ERROR(
ShapeInference::InferAllToAllShape(operand_shape, split_dimension,
concat_dimension, split_count)
.status());
// Split into N parts.
std::vector<XlaOp> slices;
slices.reserve(split_count);
const int64 block_size =
operand_shape.dimensions(split_dimension) / split_count;
for (int i = 0; i < split_count; i++) {
slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size,
/*limit_index=*/(i + 1) * block_size,
/*stride=*/1, /*dimno=*/split_dimension));
}
// Handle data communication.
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices));
std::vector<const Shape*> slice_shape_ptrs;
c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs),
[](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));
for (const ReplicaGroup& group : replica_groups) {
*instr.add_replica_groups() = group;
}
TF_ASSIGN_OR_RETURN(
XlaOp alltoall,
AddInstruction(std::move(instr), HloOpcode::kAllToAll, slices));
// Concat the N received parts.
std::vector<XlaOp> received;
received.reserve(split_count);
for (int i = 0; i < split_count; i++) {
received.push_back(this->GetTupleElement(alltoall, i));
}
return this->ConcatInDim(received, concat_dimension);
});
}
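Worked through with the shapes in the builder test later in this diff: an F32[4,16] operand with split_dimension=1, concat_dimension=0, and split_count=2 has block_size = 16 / 2 = 8, so it is first sliced into two F32[4,8] parts; the kAllToAll instruction exchanges those parts across replicas and returns them as a tuple, and concatenating the two received F32[4,8] elements along dimension 0 produces the F32[8,8] result the test checks for.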
XlaOp XlaBuilder::SelectAndScatter( XlaOp XlaBuilder::SelectAndScatter(
const XlaOp& operand, const XlaComputation& select, const XlaOp& operand, const XlaComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions, tensorflow::gtl::ArraySlice<int64> window_dimensions,
@ -2163,11 +2207,6 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
LookUpInstruction(root_op)); LookUpInstruction(root_op));
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
if (!CanBeRoot(opcode)) {
return InvalidArgument("the operand with opcode %s cannot be root",
root->opcode().c_str());
}
HloComputationProto entry; HloComputationProto entry;
entry.set_id(GetUniqueId()); // Give the computation a global unique id. entry.set_id(GetUniqueId()); // Give the computation a global unique id.
@ -2693,6 +2732,13 @@ XlaOp CrossReplicaSum(
replica_group_ids, channel_id); replica_group_ids, channel_id);
} }
XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups) {
return operand.builder()->AllToAll(operand, split_dimension, concat_dimension,
split_count, replica_groups);
}
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions, tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<int64> window_strides,

View File

@ -195,9 +195,14 @@ class XlaBuilder {
// Builds the computation with the requested operations, or returns a non-ok // Builds the computation with the requested operations, or returns a non-ok
// status. Note that all ops that have been enqueued will be moved to the // status. Note that all ops that have been enqueued will be moved to the
// computation being returned. // computation being returned. The root of the computation will be the last
// added operation.
StatusOr<XlaComputation> Build(); StatusOr<XlaComputation> Build();
// Overload of Build which specifies a particular root instruction for the
// computation.
StatusOr<XlaComputation> Build(XlaOp root);
// Builds the computation with the requested operations, or notes an error in // Builds the computation with the requested operations, or notes an error in
// the parent XlaBuilder and returns an empty computation if building failed. // the parent XlaBuilder and returns an empty computation if building failed.
// This function is intended to be used where the returned XlaComputation is // This function is intended to be used where the returned XlaComputation is
@ -225,9 +230,14 @@ class XlaBuilder {
// Returns the shape of the given op. // Returns the shape of the given op.
StatusOr<Shape> GetShape(const XlaOp& op) const; StatusOr<Shape> GetShape(const XlaOp& op) const;
// Returns the (inferred) result for the current computation's shape. // Returns the (inferred) result for the current computation's shape. This
// assumes the root instruction is the last added instruction.
StatusOr<ProgramShape> GetProgramShape() const; StatusOr<ProgramShape> GetProgramShape() const;
// Returns the (inferred) result for the current computation's shape using the
// given operation as the root.
StatusOr<ProgramShape> GetProgramShape(XlaOp root) const;
// Reports an error to the builder, by // Reports an error to the builder, by
// * storing it internally and capturing a backtrace if it's the first error // * storing it internally and capturing a backtrace if it's the first error
// (this deferred value will be produced on the call to // (this deferred value will be produced on the call to
@ -255,6 +265,9 @@ class XlaBuilder {
StatusOr<bool> IsConstant(const XlaOp& operand) const; StatusOr<bool> IsConstant(const XlaOp& operand) const;
private: private:
// Build helper which takes the id of the root operation.
StatusOr<XlaComputation> Build(int64 root_id);
// Enqueues a "retrieve parameter value" instruction for a parameter that was // Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation. // passed to the computation.
XlaOp Parameter(int64 parameter_number, const Shape& shape, XlaOp Parameter(int64 parameter_number, const Shape& shape,
@ -686,9 +699,9 @@ class XlaBuilder {
// For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
// //
// - `channel_id`: for Allreduce nodes from different models, if they have the // - `channel_id`: for Allreduce nodes from different modules, if they have
// same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be // the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will
// applied cross models. // not be applied cross modules.
// //
// TODO(b/79737069): Rename this to AllReduce when it's ready to use. // TODO(b/79737069): Rename this to AllReduce when it's ready to use.
XlaOp CrossReplicaSum( XlaOp CrossReplicaSum(
@ -697,6 +710,13 @@ class XlaBuilder {
const tensorflow::gtl::optional<ChannelHandle>& channel_id = const tensorflow::gtl::optional<ChannelHandle>& channel_id =
tensorflow::gtl::nullopt); tensorflow::gtl::nullopt);
// Enqueues an operation that does an Alltoall of the operand across cores.
//
// TODO(b/110096724): This is NOT YET ready to use.
XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups);
// Enqueues an operation that scatters the `source` array to the selected // Enqueues an operation that scatters the `source` array to the selected
// indices of each window. // indices of each window.
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
@ -969,9 +989,8 @@ class XlaBuilder {
// shape. // shape.
StatusOr<XlaOp> Reshape(const Shape& shape, const XlaOp& operand); StatusOr<XlaOp> Reshape(const Shape& shape, const XlaOp& operand);
// Returns the (inferred) result for the program shape for the current // Returns the (inferred) result for the program shape using the given root.
// computation and fills the root_id in the pointer. StatusOr<ProgramShape> GetProgramShape(int64 root_id) const;
StatusOr<ProgramShape> GetProgramShape(int64* root_id) const;
// Returns shapes for the operands. // Returns shapes for the operands.
StatusOr<std::vector<Shape>> GetOperandShapes( StatusOr<std::vector<Shape>> GetOperandShapes(
@ -1234,6 +1253,9 @@ class XlaBuilder {
const XlaOp& operand, const XlaComputation& computation, const XlaOp& operand, const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> replica_group_ids, tensorflow::gtl::ArraySlice<int64> replica_group_ids,
const tensorflow::gtl::optional<ChannelHandle>& channel_id); const tensorflow::gtl::optional<ChannelHandle>& channel_id);
friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups);
friend XlaOp SelectAndScatter( friend XlaOp SelectAndScatter(
const XlaOp& operand, const XlaComputation& select, const XlaOp& operand, const XlaComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions, tensorflow::gtl::ArraySlice<int64> window_dimensions,
@ -1820,9 +1842,9 @@ XlaOp CrossReplicaSum(
// For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
// //
// - `channel_id`: for Allreduce nodes from different models, if they have the // - `channel_id`: for Allreduce nodes from different modules, if they have the
// same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be
// applied cross models. // applied cross modules.
// //
// TODO(b/79737069): Rename this to AllReduce when it's ready to use. // TODO(b/79737069): Rename this to AllReduce when it's ready to use.
XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation, XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation,
@ -1830,6 +1852,13 @@ XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation,
const tensorflow::gtl::optional<ChannelHandle>& const tensorflow::gtl::optional<ChannelHandle>&
channel_id = tensorflow::gtl::nullopt); channel_id = tensorflow::gtl::nullopt);
// Enqueues an operation that does an Alltoall of the operand across cores.
//
// TODO(b/110096724): This is NOT YET ready to use.
XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups = {});
// Enqueues an operation that scatters the `source` array to the selected // Enqueues an operation that scatters the `source` array to the selected
// indices of each window. // indices of each window.
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla { namespace xla {
@ -46,6 +47,17 @@ class XlaBuilderTest : public ::testing::Test {
return HloModule::CreateFromProto(proto, config); return HloModule::CreateFromProto(proto, config);
} }
// Overload which explicitly specifies the root instruction.
StatusOr<std::unique_ptr<HloModule>> BuildHloModule(XlaBuilder* b,
XlaOp root) {
TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build(root));
const HloModuleProto& proto = computation.proto();
TF_ASSIGN_OR_RETURN(const auto& config,
HloModule::CreateModuleConfigFromProto(
proto, legacy_flags::GetDebugOptionsFromFlags()));
return HloModule::CreateFromProto(proto, config);
}
// Returns the name of the test currently being run. // Returns the name of the test currently being run.
string TestName() const { string TestName() const {
return ::testing::UnitTest::GetInstance()->current_test_info()->name(); return ::testing::UnitTest::GetInstance()->current_test_info()->name();
@ -293,6 +305,21 @@ TEST_F(XlaBuilderTest, Transpose) {
EXPECT_THAT(root, op::Transpose(op::Parameter())); EXPECT_THAT(root, op::Transpose(op::Parameter()));
} }
TEST_F(XlaBuilderTest, AllToAll) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0,
/*split_count=*/2);
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();
// AllToAll is decomposed into slices -> all-to-all -> gte -> concat.
EXPECT_EQ(root->opcode(), HloOpcode::kConcatenate);
EXPECT_EQ(root->operand(0)->operand(0)->opcode(), HloOpcode::kAllToAll);
EXPECT_TRUE(
ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8})));
}
TEST_F(XlaBuilderTest, ReportError) { TEST_F(XlaBuilderTest, ReportError) {
XlaBuilder b(TestName()); XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
@ -320,5 +347,45 @@ TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) {
EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error")); EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error"));
} }
TEST_F(XlaBuilderTest, BuildWithSpecificRoot) {
XlaBuilder b(TestName());
XlaOp constant = ConstantR0<float>(&b, 1.0);
Add(constant, ConstantR0<float>(&b, 2.0));
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/constant));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Constant());
}
TEST_F(XlaBuilderTest, BuildWithSpecificRootAndMultipleParameters) {
// Specifying a particular root in Build should still include all entry
// parameters.
XlaBuilder b(TestName());
const Shape shape = ShapeUtil::MakeShape(F32, {42, 123});
XlaOp x = Parameter(&b, 0, shape, "x");
XlaOp y = Parameter(&b, 1, shape, "y");
XlaOp z = Parameter(&b, 2, shape, "z");
Add(x, Sub(y, z));
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/x));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Parameter());
EXPECT_EQ(module->entry_computation()->num_parameters(), 3);
EXPECT_EQ(module->entry_computation()->instruction_count(), 5);
}
TEST_F(XlaBuilderTest, BuildWithSpecificRootWithWrongBuilder) {
XlaBuilder b(TestName());
XlaBuilder other_b(TestName());
const Shape shape = ShapeUtil::MakeShape(F32, {42, 123});
Parameter(&b, 0, shape, "param");
XlaOp other_param = Parameter(&other_b, 0, shape, "other_param");
Status status = b.Build(other_param).status();
ASSERT_IS_NOT_OK(status);
EXPECT_THAT(
status.error_message(),
::testing::HasSubstr("root operation is not in this computation"));
}
} // namespace } // namespace
} // namespace xla } // namespace xla

View File

@ -1,33 +0,0 @@
# Description:
# The new XLA client libraries.
licenses(["notice"]) # Apache 2.0
package(default_visibility = [":friends"])
package_group(
name = "friends",
includes = [
"//tensorflow/compiler/xla:friends",
],
)
# Filegroup used to collect source files for dependency checking.
filegroup(
name = "c_srcs",
data = glob([
"**/*.cc",
"**/*.h",
]),
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
cc_library(
name = "xla_builder",
hdrs = ["xla_builder.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla/client:xla_builder",
],
)

View File

@ -71,7 +71,7 @@ std::ostream& operator<<(std::ostream& out, const Literal& literal) {
return out; return out;
} }
Literal::StrideConfig::StrideConfig( MutableLiteralBase::StrideConfig::StrideConfig(
const Shape& source_shape, const Shape& dest_shape, const Shape& source_shape, const Shape& dest_shape,
tensorflow::gtl::ArraySlice<int64> dimensions) tensorflow::gtl::ArraySlice<int64> dimensions)
: dimensions(dimensions), : dimensions(dimensions),
@ -133,7 +133,8 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
} }
Literal::Literal(const Shape& shape, bool allocate_arrays) Literal::Literal(const Shape& shape, bool allocate_arrays)
: LiteralBase(), shape_(MakeUnique<Shape>(shape)) { : MutableLiteralBase() {
shape_ = MakeUnique<Shape>(shape);
CHECK(LayoutUtil::HasLayout(*shape_)); CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece(); root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get()); root_piece_->set_subshape(shape_.get());
@ -159,7 +160,9 @@ void Literal::DeallocateBuffers() {
}); });
} }
Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); } Literal::Literal(Literal&& other) : MutableLiteralBase() {
*this = std::move(other);
}
Literal& Literal::operator=(Literal&& other) { Literal& Literal::operator=(Literal&& other) {
DCHECK(&other.root_piece_->subshape() == other.shape_.get()); DCHECK(&other.root_piece_->subshape() == other.shape_.get());
@ -187,12 +190,13 @@ const SparseIndexArray* LiteralBase::sparse_indices(
return piece(shape_index).sparse_indices(); return piece(shape_index).sparse_indices();
} }
SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { SparseIndexArray* MutableLiteralBase::sparse_indices(
const ShapeIndex& shape_index) {
return piece(shape_index).sparse_indices(); return piece(shape_index).sparse_indices();
} }
template <typename NativeT> template <typename NativeT>
Status Literal::CopySliceFromInternal( Status MutableLiteralBase::CopySliceFromInternal(
const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base, const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
tensorflow::gtl::ArraySlice<int64> dest_base, tensorflow::gtl::ArraySlice<int64> dest_base,
tensorflow::gtl::ArraySlice<int64> copy_size) { tensorflow::gtl::ArraySlice<int64> copy_size) {
@ -225,7 +229,7 @@ Status Literal::CopySliceFromInternal(
// proper stride size at the matching dimension. // proper stride size at the matching dimension.
DimensionVector src_indexes(src_base.size(), 0); DimensionVector src_indexes(src_base.size(), 0);
DimensionVector dest_indexes(dest_base.size(), 0); DimensionVector dest_indexes(dest_base.size(), 0);
Literal::StrideConfig stride_config(src_literal.shape(), shape(), MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(),
copy_size); copy_size);
auto copy_proc = [&](tensorflow::gtl::ArraySlice<int64> indexes) { auto copy_proc = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
@ -253,7 +257,8 @@ Status Literal::CopySliceFromInternal(
return Status::OK(); return Status::OK();
} }
Status Literal::CopyElementFrom(const LiteralSlice& src_literal, Status MutableLiteralBase::CopyElementFrom(
const LiteralSlice& src_literal,
tensorflow::gtl::ArraySlice<int64> src_index, tensorflow::gtl::ArraySlice<int64> src_index,
tensorflow::gtl::ArraySlice<int64> dest_index) { tensorflow::gtl::ArraySlice<int64> dest_index) {
DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
@ -275,8 +280,8 @@ Status Literal::CopyElementFrom(const LiteralSlice& src_literal,
return Status::OK(); return Status::OK();
} }
/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto( /* static */ StatusOr<std::unique_ptr<Literal>>
const LiteralProto& proto) { MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
if (!proto.has_shape()) { if (!proto.has_shape()) {
return InvalidArgument("LiteralProto has no shape"); return InvalidArgument("LiteralProto has no shape");
} }
@ -405,7 +410,7 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
return Status::OK(); return Status::OK();
} }
Status Literal::CopyFrom(const LiteralSlice& src_literal, Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
const ShapeIndex& dest_shape_index, const ShapeIndex& dest_shape_index,
const ShapeIndex& src_shape_index) { const ShapeIndex& src_shape_index) {
const Shape& dest_subshape = const Shape& dest_subshape =
@ -482,7 +487,8 @@ Status Literal::MoveFrom(Literal&& src_literal,
return Status::OK(); return Status::OK();
} }
Status Literal::CopySliceFrom(const LiteralSlice& src_literal, Status MutableLiteralBase::CopySliceFrom(
const LiteralSlice& src_literal,
tensorflow::gtl::ArraySlice<int64> src_base, tensorflow::gtl::ArraySlice<int64> src_base,
tensorflow::gtl::ArraySlice<int64> dest_base, tensorflow::gtl::ArraySlice<int64> dest_base,
tensorflow::gtl::ArraySlice<int64> copy_size) { tensorflow::gtl::ArraySlice<int64> copy_size) {
@ -543,7 +549,7 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
shape().element_type()); shape().element_type());
} }
void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
CHECK(ShapeUtil::IsArray(shape())); CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(ShapeUtil::Rank(shape()), 1);
CHECK_EQ(element_count(), values.bits()); CHECK_EQ(element_count(), values.bits());
@ -895,8 +901,8 @@ size_t LiteralBase::Hash() const {
return hash_value; return hash_value;
} }
Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index, Status MutableLiteralBase::SetIntegralAsS64(
int64 value) { tensorflow::gtl::ArraySlice<int64> multi_index, int64 value) {
CHECK(LayoutUtil::IsDenseArray(shape())); CHECK(LayoutUtil::IsDenseArray(shape()));
switch (shape().element_type()) { switch (shape().element_type()) {
case PRED: case PRED:
@ -933,7 +939,7 @@ tensorflow::gtl::ArraySlice<int64> LiteralBase::GetSparseIndex(
return p.sparse_indices()->At(sparse_element_number); return p.sparse_indices()->At(sparse_element_number);
} }
void Literal::SortSparseElements(const ShapeIndex& shape_index) { void MutableLiteralBase::SortSparseElements(const ShapeIndex& shape_index) {
piece(shape_index).SortSparseElements(); piece(shape_index).SortSparseElements();
} }
@ -1391,11 +1397,11 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
elements.push_back(std::move(*new_element)); elements.push_back(std::move(*new_element));
} }
auto converted = MakeUnique<Literal>(); auto converted = MakeUnique<Literal>();
*converted = Literal::MoveIntoTuple(&elements); *converted = MutableLiteralBase::MoveIntoTuple(&elements);
return std::move(converted); return std::move(converted);
} }
/* static */ Literal Literal::MoveIntoTuple( /* static */ Literal MutableLiteralBase::MoveIntoTuple(
tensorflow::gtl::MutableArraySlice<Literal> elements) { tensorflow::gtl::MutableArraySlice<Literal> elements) {
std::vector<Shape> element_shapes; std::vector<Shape> element_shapes;
for (const Literal& element : elements) { for (const Literal& element : elements) {
@ -1808,7 +1814,8 @@ Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
} // namespace } // namespace
Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
// These conditions should have been checked in Literal::CreateFromProto. // These conditions should have been checked in
// MutableLiteralBase::CreateFromProto.
TF_RET_CHECK(proto.has_shape()); TF_RET_CHECK(proto.has_shape());
TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape())); TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape()));
@ -1900,7 +1907,7 @@ const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
return piece(shape_index).untyped_data(); return piece(shape_index).untyped_data();
} }
void* Literal::untyped_data(const ShapeIndex& shape_index) { void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) {
return piece(shape_index).untyped_data(); return piece(shape_index).untyped_data();
} }
@ -1916,6 +1923,127 @@ string LiteralBase::GetR1U8AsString() const {
ShapeUtil::ElementsIn(shape())); ShapeUtil::ElementsIn(shape()));
} }
void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape,
Piece* src_piece,
Piece* dest_piece) {
DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape()))
<< "src_piece has shape: "
<< ShapeUtil::HumanString(src_piece->subshape())
<< "dest_piece has shape: "
<< ShapeUtil::HumanString(dest_piece->subshape());
if (ShapeUtil::IsTuple(shape)) {
for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
const Shape& subshape = shape.tuple_shapes(i);
auto child_piece = Piece();
child_piece.set_subshape(&subshape);
CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece);
dest_piece->emplace_back(std::move(child_piece));
}
} else if (ShapeUtil::IsArray(shape)) {
dest_piece->set_buffer(src_piece->buffer());
} else {
// If the shape is neither an array nor tuple, then it must be
// zero-sized. Otherwise, some memory needs to be allocated for it.
CHECK_EQ(dest_piece->size_bytes(), 0);
}
}
MutableLiteralBase::~MutableLiteralBase() {}
MutableBorrowingLiteral::MutableBorrowingLiteral(
const MutableBorrowingLiteral& literal)
: MutableLiteralBase() {
shape_ = MakeUnique<Shape>(literal.shape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
}
MutableBorrowingLiteral& MutableBorrowingLiteral::operator=(
const MutableBorrowingLiteral& literal) {
shape_ = MakeUnique<Shape>(literal.shape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
return *this;
}
MutableBorrowingLiteral::MutableBorrowingLiteral(
const MutableLiteralBase& literal)
: MutableLiteralBase() {
shape_ = MakeUnique<Shape>(literal.shape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
}
MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal)
: MutableLiteralBase() {
shape_ = MakeUnique<Shape>(literal->shape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_);
}
MutableBorrowingLiteral::MutableBorrowingLiteral(
MutableBorrowingLiteral literal, const ShapeIndex& view_root)
: MutableLiteralBase() {
shape_ = MakeUnique<Shape>(literal.piece(view_root).subshape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_);
}
MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr,
const Shape& shape)
: MutableLiteralBase() {
shape_ = MakeUnique<Shape>(shape);
CHECK(LayoutUtil::HasLayout(*shape_));
CHECK(!ShapeUtil::IsTuple(*shape_));
root_piece_ = new Piece();
root_piece_->set_buffer(const_cast<char*>(src_buf_ptr));
root_piece_->set_subshape(shape_.get());
}
MutableBorrowingLiteral::~MutableBorrowingLiteral() {
if (root_piece_ != nullptr) {
root_piece_->ForEachMutableSubpiece(
[&](const ShapeIndex& index, Piece* piece) {
if (piece->buffer() != nullptr) {
delete piece->sparse_indices();
}
});
delete root_piece_;
}
}
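A hedged usage sketch of the new borrowing type (hypothetical caller code, not part of the change; it assumes the Set mutator that MutableLiteralBase inherits from the old Literal interface):

    // Wrap a caller-owned, densely laid out buffer. No copy is made; writes
    // through the literal go directly to `backing`.
    std::vector<float> backing(6, 0.0f);
    Shape shape = ShapeUtil::MakeShapeWithDescendingLayout(F32, {2, 3});
    MutableBorrowingLiteral view(reinterpret_cast<const char*>(backing.data()),
                                 shape);
    view.Set<float>({1, 2}, 42.0f);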
LiteralSlice::LiteralSlice(const LiteralBase& literal)
: LiteralBase(), root_piece_(&literal.root_piece()) {}
LiteralSlice::LiteralSlice(const LiteralBase& literal,
const ShapeIndex& view_root)
: LiteralBase(), root_piece_(&literal.piece(view_root)) {}
void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
CHECK(ShapeUtil::IsTuple(shape)); CHECK(ShapeUtil::IsTuple(shape));
for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
@ -1932,13 +2060,6 @@ void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
} }
} }
LiteralSlice::LiteralSlice(const LiteralBase& literal)
: LiteralBase(), root_piece_(&literal.root_piece()) {}
LiteralSlice::LiteralSlice(const LiteralBase& literal,
const ShapeIndex& view_root)
: LiteralBase(), root_piece_(&literal.piece(view_root)) {}
BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
: LiteralBase(), shape_(MakeUnique<Shape>(shape)) { : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
CHECK(ShapeUtil::IsArray(*shape_)); CHECK(ShapeUtil::IsArray(*shape_));

View File

@ -310,9 +310,10 @@ class LiteralBase {
// type of literal itself (0 for numeric types, and false for predicates). // type of literal itself (0 for numeric types, and false for predicates).
// //
// Note: It's an antipattern to use this method then immediately call // Note: It's an antipattern to use this method then immediately call
// Literal::Populate on the result (since that results in zero initialization, // MutableLiteralBase::Populate on the result (since that results in zero
// then reinitialization). Consider if a call to MakeUnique<Literal>(shape), // initialization, then reinitialization). Consider if a call to
// followed by the call to Literal::Populate can be used instead. // MakeUnique<Literal>(shape), followed by the call to
// MutableLiteralBase::Populate can be used instead.
static std::unique_ptr<Literal> CreateFromShape(const Shape& shape); static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
protected: protected:
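A sketch of the pattern this note recommends (a hedged illustration, not code from the change; it assumes Literal's templated Populate method, declared elsewhere in this header):

    // Allocate the literal once, then fill it in place rather than calling
    // CreateFromShape (zero initialization) followed by Populate.
    auto literal = MakeUnique<Literal>(ShapeUtil::MakeShape(F32, {2, 2}));
    Status status = literal->Populate<float>(
        [](tensorflow::gtl::ArraySlice<int64> index) {
          return static_cast<float>(index[0] * 2 + index[1]);
        });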
@ -534,7 +535,7 @@ class LiteralBase {
virtual const Piece& root_piece() const = 0; virtual const Piece& root_piece() const = 0;
// LiteralSlice and Literal must access Pieces of other Literals. // LiteralSlice and Literal must access Pieces of other Literals.
friend class Literal; friend class MutableLiteralBase;
friend class LiteralSlice; friend class LiteralSlice;
friend class BorrowingLiteral; friend class BorrowingLiteral;
@ -545,33 +546,10 @@ class LiteralBase {
tensorflow::gtl::ArraySlice<int64> start_indices) const; tensorflow::gtl::ArraySlice<int64> start_indices) const;
}; };
// Class representing literal values in XLA. // Abstract base class representing a mutable literal in XLA.
// class MutableLiteralBase : public LiteralBase {
// The underlying buffer and shape is always owned by this class.
class Literal : public LiteralBase {
public: public:
Literal() : Literal(ShapeUtil::MakeNil()) {} virtual ~MutableLiteralBase() = 0;
// Create a literal of the given shape. The literal is allocated sufficient
// memory to hold the shape. Memory is uninitialized.
explicit Literal(const Shape& shape);
virtual ~Literal();
// Literals are moveable, but not copyable. To copy a literal use
// Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
// of literals which can be expensive.
Literal(const Literal& other) = delete;
Literal& operator=(const Literal& other) = delete;
Literal(Literal&& other);
// 'allocate_arrays' indicates whether to allocate memory for the arrays in
// the shape. If false, buffer pointers inside of the Literal::Pieces are set
// to nullptr.
Literal(const Shape& shape, bool allocate_arrays);
Literal& operator=(Literal&& other);
// TODO(b/67651157): Remove this accessor. Literal users should not be able to
// mutate the shape as this can produce malformed Literals.
Shape* mutable_shape_do_not_use() { return shape_.get(); }
// Returns a MutableArraySlice view of the array for this literal for the // Returns a MutableArraySlice view of the array for this literal for the
// given NativeT (e.g., float). CHECKs if the subshape of the literal at the // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
@ -587,6 +565,10 @@ class Literal : public LiteralBase {
// is not a sparse array. // is not a sparse array.
SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
// TODO(b/67651157): Remove this accessor. Literal users should not be able to
// mutate the shape as this can produce malformed Literals.
Shape* mutable_shape_do_not_use() { return shape_.get(); }
// Returns a pointer to the underlying buffer holding the array at the given // Returns a pointer to the underlying buffer holding the array at the given
// shape index. CHECKs if the subshape of the literal at the given ShapeIndex // shape index. CHECKs if the subshape of the literal at the given ShapeIndex
// is not array. // is not array.
@ -613,21 +595,6 @@ class Literal : public LiteralBase {
const ShapeIndex& dest_shape_index = {}, const ShapeIndex& dest_shape_index = {},
const ShapeIndex& src_shape_index = {}); const ShapeIndex& src_shape_index = {});
// Returns a vector containing the tuple elements of this Literal as separate
// Literals. This Literal must be tuple-shaped and can be a nested tuple. The
// elements are moved into the new Literals; no data is copied. Upon return
// this Literal is set to a nil shape (empty tuple)
std::vector<Literal> DecomposeTuple();
// Similar to CopyFrom, but with move semantincs. The subshape of this literal
// rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
// (layouts and shapes must match), but need not be arrays. The memory
// allocated in this literal for the subshape at dest_shape_index is
// deallocated, and the respective buffers are replaced with those in
// src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
Status MoveFrom(Literal&& src_literal,
const ShapeIndex& dest_shape_index = {});
// Copies the values from src_literal, starting at src_base shape indexes, // Copies the values from src_literal, starting at src_base shape indexes,
// to this literal, starting at dest_base, where the copy size in each // to this literal, starting at dest_base, where the copy size in each
// dimension is specified by copy_size. // dimension is specified by copy_size.
@ -730,12 +697,7 @@ class Literal : public LiteralBase {
static StatusOr<std::unique_ptr<Literal>> CreateFromProto( static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
const LiteralProto& proto); const LiteralProto& proto);
private: protected:
// Recursively sets the subshapes and buffers of all subpieces rooted at
// 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
// the shape.
void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
// Returns the piece at the given ShapeIndex. // Returns the piece at the given ShapeIndex.
Piece& piece(const ShapeIndex& shape_index) { Piece& piece(const ShapeIndex& shape_index) {
return const_cast<Piece&>(LiteralBase::piece(shape_index)); return const_cast<Piece&>(LiteralBase::piece(shape_index));
@ -783,12 +745,83 @@ class Literal : public LiteralBase {
template <typename NativeT, typename FnType> template <typename NativeT, typename FnType>
Status PopulateInternal(const FnType& generator, bool parallel); Status PopulateInternal(const FnType& generator, bool parallel);
friend class LiteralBase;
friend class MutableBorrowingLiteral;
};
std::ostream& operator<<(std::ostream& out, const Literal& literal);
// The underlying buffer and shape is always owned by this class.
class Literal : public MutableLiteralBase {
public:
Literal() : Literal(ShapeUtil::MakeNil()) {}
// Create a literal of the given shape. The literal is allocated sufficient
// memory to hold the shape. Memory is uninitialized.
explicit Literal(const Shape& shape);
virtual ~Literal();
// Literals are moveable, but not copyable. To copy a literal use
// Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
// of literals which can be expensive.
Literal(const Literal& other) = delete;
Literal& operator=(const Literal& other) = delete;
Literal(Literal&& other);
// 'allocate_arrays' indicates whether to allocate memory for the arrays in
// the shape. If false, buffer pointers inside of the Literal::Pieces are set
// to nullptr.
Literal(const Shape& shape, bool allocate_arrays);
Literal& operator=(Literal&& other);
// Similar to CopyFrom, but with move semantics. The subshape of this literal
// rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
// (layouts and shapes must match), but need not be arrays. The memory
// allocated in this literal for the subshape at dest_shape_index is
// deallocated, and the respective buffers are replaced with those in
// src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
virtual Status MoveFrom(Literal&& src_literal,
const ShapeIndex& dest_shape_index = {});
// Returns a vector containing the tuple elements of this Literal as separate
// Literals. This Literal must be tuple-shaped and can be a nested tuple. The
// elements are moved into the new Literals; no data is copied. Upon return
// this Literal is set to a nil shape (empty tuple)
std::vector<Literal> DecomposeTuple();
private:
// Deallocate the buffers held by this literal.
void DeallocateBuffers();

// Recursively sets the subshapes and buffers of all subpieces rooted at
// 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
// the shape.
void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
};
// The underlying buffer is not owned by this class and is always owned by
// others. The shape is not owned by this class and not mutable.
class MutableBorrowingLiteral : public MutableLiteralBase {
public:
virtual ~MutableBorrowingLiteral();
MutableBorrowingLiteral() : MutableLiteralBase() {}
MutableBorrowingLiteral(const MutableBorrowingLiteral& literal);
MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal);
// Implicit conversion constructors.
MutableBorrowingLiteral(const MutableLiteralBase& literal);
MutableBorrowingLiteral(MutableLiteralBase* literal);
MutableBorrowingLiteral(MutableBorrowingLiteral literal,
const ShapeIndex& view_root);
MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
private:
// Recursively copies the subtree from the `src_piece` at the given child
// index to the `dest_piece`. For buffers only the pointers are copied, but
// not the content.
void CopyPieceSubtree(const Shape& shape, Piece* src_piece,
Piece* dest_piece);
};
std::ostream& operator<<(std::ostream& out, const Literal& literal);
// A read-only view of a Literal. A LiteralSlice contains pointers to shape and
// literal buffers always owned by others.
@ -831,9 +864,9 @@ class BorrowingLiteral : public LiteralBase {
  const Piece& root_piece() const override { return root_piece_; };
  Piece root_piece_;

- // Shape of this literal. Stored as unique_ptr so such that the (default)
- // move construction of this class would be trivially correct: the pointer to
- // Shape root_piece_ stores will still point to the correct address.
+ // Shape of this literal. Stored as unique_ptr such that the (default) move
+ // construction of this class would be trivially correct: the pointer to Shape
+ // root_piece_ stores will still point to the correct address.
  std::unique_ptr<Shape> shape_;
};
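The comment above relies on a general C++ property worth spelling out: a defaulted move only copies the unique_ptr itself, so the pointee's address does not change and raw pointers into it survive the move. A standalone toy sketch of that argument, using plain C++ types rather than XLA's Shape/Piece:

#include <cassert>
#include <memory>
#include <utility>

// Toy stand-in for the shape_/root_piece_ relationship: `view` points into the
// object owned through a unique_ptr, so moving the holder leaves it valid.
struct Holder {
  std::unique_ptr<int> owned = std::make_unique<int>(42);
  int* view = owned.get();  // analogous to root_piece_ pointing at *shape_
};

int main() {
  Holder a;
  int* before = a.view;
  Holder b = std::move(a);   // defaulted move: the pointers are just copied
  assert(b.view == before);  // still points at the same int
  return 0;
}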
@ -886,7 +919,7 @@ tensorflow::gtl::ArraySlice<NativeT> LiteralBase::data(
}

template <typename NativeT>
-tensorflow::gtl::MutableArraySlice<NativeT> Literal::data(
+tensorflow::gtl::MutableArraySlice<NativeT> MutableLiteralBase::data(
    const ShapeIndex& shape_index) {
  return piece(shape_index).data<NativeT>();
}
@ -904,14 +937,15 @@ inline NativeT LiteralBase::Get(
}

template <typename NativeT>
-inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+inline void MutableLiteralBase::Set(
+   tensorflow::gtl::ArraySlice<int64> multi_index,
    const ShapeIndex& shape_index, NativeT value) {
  return piece(shape_index).Set<NativeT>(multi_index, value);
}

template <typename NativeT>
-inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
-                         NativeT value) {
+inline void MutableLiteralBase::Set(
+   tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value) {
  return root_piece().Set<NativeT>(multi_index, value);
}
@ -929,7 +963,7 @@ NativeT LiteralBase::GetSparseElement(int64 sparse_element_number,
}

template <typename NativeT>
-void Literal::AppendSparseElement(
+void MutableLiteralBase::AppendSparseElement(
    tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value,
    const ShapeIndex& shape_index) {
  Piece& p = piece(shape_index);
@ -959,7 +993,8 @@ void LiteralBase::EachCell(
}

template <typename NativeT>
-inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) {
+inline void MutableLiteralBase::PopulateR1(
+   tensorflow::gtl::ArraySlice<NativeT> values) {
  CHECK(ShapeUtil::IsArray(shape()));
  CHECK_EQ(ShapeUtil::Rank(shape()), 1);
  CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
@ -971,7 +1006,7 @@ inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) {
}

template <typename NativeT>
-void Literal::PopulateR2(
+void MutableLiteralBase::PopulateR2(
    std::initializer_list<std::initializer_list<NativeT>> values) {
  CHECK(ShapeUtil::IsArray(shape()));
  CHECK_EQ(ShapeUtil::Rank(shape()), 2);
@ -996,7 +1031,7 @@ void Literal::PopulateR2(
}

template <typename NativeT>
-void Literal::PopulateFromArray(const Array<NativeT>& values) {
+void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) {
  CHECK(ShapeUtil::IsArray(shape()));
  CHECK_EQ(shape().element_type(),
           primitive_util::NativeToPrimitiveType<NativeT>());
@ -1009,23 +1044,23 @@ void Literal::PopulateFromArray(const Array<NativeT>& values) {
}

template <typename NativeT>
-void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
+void MutableLiteralBase::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
  PopulateFromArray(values);
}

template <typename NativeT>
-void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
+void MutableLiteralBase::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
  PopulateFromArray(values);
}

template <typename NativeT>
-void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
+void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
  PopulateFromArray(values);
}

template <typename NativeT>
-void Literal::PopulateSparse(SparseIndexArray indices,
-                             tensorflow::gtl::ArraySlice<NativeT> values,
-                             bool sort) {
+void MutableLiteralBase::PopulateSparse(
+   SparseIndexArray indices, tensorflow::gtl::ArraySlice<NativeT> values,
+   bool sort) {
  CHECK(LayoutUtil::IsSparseArray(shape()));
  int rank = ShapeUtil::Rank(shape());
@ -1049,7 +1084,8 @@ void Literal::PopulateSparse(SparseIndexArray indices,
}

template <typename NativeT, typename FnType>
-Status Literal::PopulateInternal(const FnType& generator, bool parallel) {
+Status MutableLiteralBase::PopulateInternal(const FnType& generator,
+                                            bool parallel) {
  const Shape& this_shape = shape();
  const int64 rank = ShapeUtil::Rank(this_shape);
  TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
@ -1092,17 +1128,17 @@ Status Literal::PopulateInternal(const FnType& generator, bool parallel) {
  return Status::OK();
}

template <typename NativeT, typename FnType>
-Status Literal::Populate(const FnType& generator) {
+Status MutableLiteralBase::Populate(const FnType& generator) {
  return PopulateInternal<NativeT>(generator, /*parallel=*/false);
}

template <typename NativeT, typename FnType>
-Status Literal::PopulateParallel(const FnType& generator) {
+Status MutableLiteralBase::PopulateParallel(const FnType& generator) {
  return PopulateInternal<NativeT>(generator, /*parallel=*/true);
}

template <typename NativeT>
-void Literal::PopulateWithValue(NativeT value) {
+void MutableLiteralBase::PopulateWithValue(NativeT value) {
  CHECK(ShapeUtil::IsArray(shape()));
  CHECK_EQ(shape().element_type(),
           primitive_util::NativeToPrimitiveType<NativeT>());
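A minimal usage sketch of the MutableLiteralBase populate/read path declared above. The shape, values, and function name are illustrative assumptions, not part of this change; only the Populate/Get/PopulateWithValue signatures shown in this diff are relied on:

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

// Fills a 2x3 F32 literal via MutableLiteralBase::Populate and reads one
// element back. The same calls work through a MutableBorrowingLiteral, since
// the mutating API now lives on the shared base class.
Literal MakePopulatedLiteral() {
  Literal literal(ShapeUtil::MakeShape(F32, {2, 3}));
  TF_CHECK_OK(literal.Populate<float>(
      [](tensorflow::gtl::ArraySlice<int64> index) {
        return static_cast<float>(index[0] * 10 + index[1]);
      }));
  CHECK_EQ(literal.Get<float>({1, 2}), 12.0f);
  return literal;  // Literal is move-only; this uses the move constructor.
}

}  // namespace xla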


@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
using tensorflow::strings::StrCat; using tensorflow::strings::StrCat;


@ -570,7 +570,7 @@ cc_library(
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
-       "//tensorflow/core:core_cpu_internal",
+       "//tensorflow/core:core_cpu_lib",
        "//tensorflow/core:lib",
        "//tensorflow/core:stream_executor_no_cuda",
        "//third_party/eigen3",
@ -613,6 +613,7 @@ cc_library(
        "//tensorflow/compiler/xla:xla_proto",
        "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
        "//tensorflow/core:lib",
+       "//tensorflow/core:ptr_util",
        "//tensorflow/core:stream_executor_no_cuda",
    ],
    alwayslink = 1,
@ -1384,6 +1385,18 @@ tf_cc_test(
    ],
)

+cc_library(
+    name = "while_loop_analysis",
+    srcs = ["while_loop_analysis.cc"],
+    hdrs = ["while_loop_analysis.h"],
+    deps = [
+        ":hlo",
+        ":hlo_evaluator",
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/core:lib",
+    ],
+)
+
cc_library(
    name = "while_loop_simplifier",
    srcs = ["while_loop_simplifier.cc"],
@ -1391,8 +1404,8 @@ cc_library(
    deps = [
        ":call_inliner",
        ":hlo",
-       ":hlo_evaluator",
        ":hlo_pass",
+       ":while_loop_analysis",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/core:lib",
    ],


@ -1803,6 +1803,12 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
}

Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
+  // TODO(b/112040122): Most of those optimizations can be done for multi-output
+  // reduces.
+  if (ShapeUtil::IsTuple(reduce->shape())) {
+    return Status::OK();
+  }
+
  auto arg = reduce->mutable_operand(0);
  auto init_value = reduce->mutable_operand(1);
  tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());


@ -48,11 +48,6 @@ namespace xla {
// compuation.
using ObjectFileData = std::vector<char>;

-// Contains the buffer sizes information needed to allocate buffers to execute
-// an ahead-of-time computation. Entries which contain -1 designate a parameter
-// which should be skipped over during allocation.
-using BufferSizes = std::vector<int64>;
-
// Abstract superclass describing the result of an ahead-of-time compilation.
class AotCompilationResult {
 public:


@ -54,12 +54,24 @@ cc_library(
    alwayslink = True,  # Contains per-platform transfer manager registration
)

+cc_library(
+    name = "buffer_info_util",
+    srcs = ["buffer_info_util.cc"],
+    hdrs = ["buffer_info_util.h"],
+    deps = [
+        "//tensorflow/compiler/tf2xla:cpu_function_runtime",
+        "//tensorflow/compiler/xla/service:buffer_assignment",
+        "//tensorflow/core:lib",
+    ],
+)
+
cc_library(
    name = "cpu_compiler",
    srcs = ["cpu_compiler.cc"],
    hdrs = ["cpu_compiler.h"],
    deps = [
        ":compiler_functor",
+       ":buffer_info_util",
        ":conv_canonicalization",
        ":cpu_copy_insertion",
        ":cpu_executable",
@ -73,6 +85,7 @@ cc_library(
        ":ir_emitter",
        ":parallel_task_assignment",
        ":simple_orc_jit",
+       "//tensorflow/compiler/tf2xla:cpu_function_runtime",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:protobuf_util",
        "//tensorflow/compiler/xla:status_macros",


@ -0,0 +1,57 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
namespace xla {
namespace cpu {
using BufferInfo = ::tensorflow::cpu_function_runtime::BufferInfo;
std::vector<BufferInfo> CreateBufferInfosFromBufferAssignment(
const BufferAssignment& buffer_assignment) {
std::vector<BufferInfo> buffer_infos;
for (const BufferAllocation& allocation : buffer_assignment.Allocations()) {
if (allocation.is_thread_local()) {
buffer_infos.push_back(BufferInfo::MakeOnStackBuffer(allocation.size()));
} else if (allocation.is_constant()) {
buffer_infos.push_back(BufferInfo::MakeConstant(allocation.size()));
} else if (allocation.is_entry_computation_parameter()) {
buffer_infos.push_back(BufferInfo::MakeEntryParameter(
/*size=*/allocation.size(),
/*param_number=*/allocation.parameter_number()));
} else {
buffer_infos.push_back(BufferInfo::MakeTempBuffer(allocation.size()));
}
}
return buffer_infos;
}
std::vector<int32> CreateArgIndexTableFromBufferInfos(
tensorflow::gtl::ArraySlice<BufferInfo> buffer_infos) {
std::vector<int32> result;
for (int64 i = 0; i < buffer_infos.size(); i++) {
if (buffer_infos[i].is_entry_parameter()) {
if (buffer_infos[i].entry_parameter_number() >= result.size()) {
result.resize(buffer_infos[i].entry_parameter_number() + 1);
}
result[buffer_infos[i].entry_parameter_number()] = i;
}
}
return result;
}
} // namespace cpu
} // namespace xla


@ -0,0 +1,42 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace cpu {
// Creates and returns a list of BufferInfo instances containing relevant
// information from `buffer_assignment`.
std::vector<::tensorflow::cpu_function_runtime::BufferInfo>
CreateBufferInfosFromBufferAssignment(
const BufferAssignment& buffer_assignment);
// Creates and returns a table containing the mapping from entry computation
// parameters to buffer allocation indices.
//
// If this function returns V then entry parameter i has buffer allocation index
// V[i].
std::vector<int32> CreateArgIndexTableFromBufferInfos(
tensorflow::gtl::ArraySlice<::tensorflow::cpu_function_runtime::BufferInfo>
buffer_infos);
} // namespace cpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_
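A hedged sketch of the V[i] contract described above, using made-up sizes and parameter numbers. Only MakeTempBuffer, MakeEntryParameter, and CreateArgIndexTableFromBufferInfos from the code above are assumed:

#include <vector>

#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace cpu {

void ArgIndexTableExample() {
  using ::tensorflow::cpu_function_runtime::BufferInfo;
  // Hypothetical layout: a temp buffer, then entry parameters 1 and 0.
  std::vector<BufferInfo> infos = {
      BufferInfo::MakeTempBuffer(/*size=*/64),
      BufferInfo::MakeEntryParameter(/*size=*/16, /*param_number=*/1),
      BufferInfo::MakeEntryParameter(/*size=*/32, /*param_number=*/0),
  };
  auto table = CreateArgIndexTableFromBufferInfos(infos);
  // Entry parameter 0 lives at allocation index 2, parameter 1 at index 1.
  CHECK_EQ(table[0], 2);
  CHECK_EQ(table[1], 1);
}

}  // namespace cpu
}  // namespace xla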


@ -50,6 +50,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
+#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
@ -103,6 +104,7 @@ limitations under the License.
namespace xla {
namespace cpu {
+using BufferInfo = ::tensorflow::cpu_function_runtime::BufferInfo;

CpuAotCompilationOptions::CpuAotCompilationOptions(
    string triple, string cpu_name, string features, string entry_point_name,
@ -120,11 +122,11 @@ se::Platform::Id CpuAotCompilationOptions::PlatformId() const {
}

CpuAotCompilationResult::CpuAotCompilationResult(
-   ObjectFileData object_file_data, BufferSizes buffer_sizes,
+   ObjectFileData object_file_data, std::vector<BufferInfo> buffer_infos,
    int64 result_buffer_index,
    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data)
    : object_file_data_(std::move(object_file_data)),
-     buffer_sizes_(std::move(buffer_sizes)),
+     buffer_infos_(std::move(buffer_infos)),
      result_buffer_index_(result_buffer_index),
      hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {}
@ -838,39 +840,14 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
    ObjectFileData object_file_data(object_file->getBufferStart(),
                                    object_file->getBufferEnd());

-   BufferSizes buffer_sizes;
-   for (const BufferAllocation& allocation : assignment->Allocations()) {
-     // Callers don't need to allocate anything for thread-local temporary
-     // buffers. They are lowered to allocas.
-     if (allocation.is_thread_local()) {
-       buffer_sizes.push_back(-1);
-       continue;
-     }
-
-     // Callers don't need to allocate anything for constant buffers. They are
-     // lowered to globals.
-     if (allocation.is_constant()) {
-       buffer_sizes.push_back(-1);
-       continue;
-     }
-
-     // Callers don't need to allocate anything for entry computation buffers,
-     // but they do need to stash the pointer to the entry computation buffer
-     // in the temp buffer table. See the comment on
-     // XlaCompiledCpuFunction::StaticData::temp_sizes.
-     if (allocation.is_entry_computation_parameter()) {
-       buffer_sizes.push_back(-allocation.parameter_number() - 2);
-       continue;
-     }
-
-     buffer_sizes.push_back(allocation.size());
-   }
+   std::vector<BufferInfo> buffer_infos =
+       CreateBufferInfosFromBufferAssignment(*assignment);

    TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
                        assignment->GetUniqueTopLevelOutputSlice());

    results.emplace_back(MakeUnique<CpuAotCompilationResult>(
-       std::move(object_file_data), std::move(buffer_sizes),
+       std::move(object_file_data), std::move(buffer_infos),
        result_slice.index(), std::move(hlo_profile_printer_data)));
  }
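For reference, a standalone sketch of the legacy integer encoding that the removed loop above produced (the convention BufferInfo now replaces). The helper name is hypothetical:

#include <cstdint>
#include <string>

// Legacy BufferSizes convention: -1 marks a buffer the caller must not
// allocate (thread-local temp or constant), -n-2 marks entry parameter n,
// and any other value is a temp/result buffer of that many bytes.
std::string DescribeLegacyBufferSize(int64_t encoded) {
  if (encoded == -1) return "skip (thread-local or constant)";
  if (encoded < -1) return "entry parameter " + std::to_string(-encoded - 2);
  return "allocate " + std::to_string(encoded) + " bytes";
}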


@ -19,6 +19,7 @@ limitations under the License.
#include <memory>

#include "llvm/Target/TargetMachine.h"
+#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
@ -78,7 +79,8 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
class CpuAotCompilationResult : public AotCompilationResult {
 public:
  CpuAotCompilationResult(
-     ObjectFileData object_file_data, BufferSizes buffer_sizes,
+     ObjectFileData object_file_data,
+     std::vector<::tensorflow::cpu_function_runtime::BufferInfo> buffer_infos,
      int64 result_buffer_index,
      std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data);
  ~CpuAotCompilationResult();
@ -88,17 +90,20 @@ class CpuAotCompilationResult : public AotCompilationResult {
  }
  const ObjectFileData& object_file_data() const { return object_file_data_; }
- const BufferSizes& buffer_sizes() const { return buffer_sizes_; }
+ const std::vector<::tensorflow::cpu_function_runtime::BufferInfo>&
+ buffer_infos() const {
+   return buffer_infos_;
+ }
  int64 result_buffer_index() const { return result_buffer_index_; }

 private:
  // Contains the compiled computation: an object file.
  const ObjectFileData object_file_data_;

- // The list of buffer sizes which should be allocated in order to execute the
- // compiled computation. These buffers are used for temporary buffers used
- // ephemerally during computation as well as the output result.
- const BufferSizes buffer_sizes_;
+ // A list of BufferInfo objects describing the buffers used by the XLA
+ // computation.
+ const std::vector<::tensorflow::cpu_function_runtime::BufferInfo>
+     buffer_infos_;

  // Contains which buffer index into |buffer_sizes| was designated to the
  // result of the computation. This buffer should be passed into the output


@ -173,7 +173,7 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor,
Status CpuTransferManager::TransferLiteralFromOutfeed(
    se::StreamExecutor* executor, const Shape& literal_shape,
-   Literal* literal) {
+   MutableBorrowingLiteral literal) {
  if (!ShapeUtil::IsTuple(literal_shape)) {
    int64 size = GetByteSizeRequirement(literal_shape);
    // Note: OSS build didn't like implicit conversion from
@ -181,18 +181,16 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
    tensorflow::gtl::ArraySlice<int64> dimensions(
        tensorflow::bit_cast<const int64*>(literal_shape.dimensions().data()),
        literal_shape.dimensions().size());
-   *literal = std::move(*LiteralUtil::CreateFromDimensions(
-       literal_shape.element_type(), dimensions));
-   TF_ASSIGN_OR_RETURN(Shape received_shape,
-                       TransferArrayBufferFromOutfeed(
-                           executor, literal->untyped_data(), size));
-   TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal->shape()))
+   TF_ASSIGN_OR_RETURN(
+       Shape received_shape,
+       TransferArrayBufferFromOutfeed(executor, literal.untyped_data(), size));
+   TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal.shape()))
        << "Shape received from outfeed "
        << ShapeUtil::HumanString(received_shape)
        << " did not match the shape that was requested for outfeed: "
        << ShapeUtil::HumanString(literal_shape);
    TF_RET_CHECK(size == GetByteSizeRequirement(received_shape));
-   *literal->mutable_shape_do_not_use() = received_shape;
+   *literal.mutable_shape_do_not_use() = received_shape;
    return Status::OK();
  }
@ -201,22 +199,12 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
        "Nested tuple outfeeds are not yet implemented on CPU.");
  }

- std::vector<std::unique_ptr<Literal>> elements;
  std::vector<std::pair<void*, int64>> buffer_data;
  for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) {
    const Shape& tuple_element_shape =
        ShapeUtil::GetTupleElementShape(literal_shape, i);
-   // Note: OSS build didn't like implicit conversion from
-   // literal_shape.dimensions() to the array slice on 2017-07-10.
-   tensorflow::gtl::ArraySlice<int64> dimensions(
-       tensorflow::bit_cast<const int64*>(
-           tuple_element_shape.dimensions().data()),
-       tuple_element_shape.dimensions().size());
-   auto empty = LiteralUtil::CreateFromDimensions(
-       tuple_element_shape.element_type(), dimensions);
    int64 size = GetByteSizeRequirement(tuple_element_shape);
-   buffer_data.push_back({empty->untyped_data(), size});
-   elements.push_back(std::move(empty));
+   buffer_data.push_back({literal.untyped_data({i}), size});
  }

  TF_ASSIGN_OR_RETURN(Shape received_shape,
@ -230,11 +218,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
  TF_RET_CHECK(GetByteSizeRequirement(literal_shape) ==
               GetByteSizeRequirement(received_shape));

- for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) {
-   *elements[i]->mutable_shape_do_not_use() = received_shape.tuple_shapes(i);
- }
- *literal = std::move(*LiteralUtil::MakeTupleOwned(std::move(elements)));
- TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape));
+ TF_RET_CHECK(ShapeUtil::Equal(literal.shape(), literal_shape));
  return Status::OK();
}
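A minimal caller-side sketch of the new signature above: the outfeed result lands directly in a caller-owned Literal through the implicit MutableBorrowingLiteral conversion. The tuple shape is illustrative, and transfer_manager/executor are assumed to come from the surrounding context:

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

// Receives a (non-nested) tuple outfeed into caller-owned storage.
Status ReadTupleFromOutfeed(TransferManager* transfer_manager,
                            se::StreamExecutor* executor) {
  Shape shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {16}), ShapeUtil::MakeShape(S32, {4})});
  Literal result(shape);
  // `&result` converts implicitly to MutableBorrowingLiteral, so the transfer
  // manager writes into result's buffers without an intermediate literal.
  TF_RETURN_IF_ERROR(
      transfer_manager->TransferLiteralFromOutfeed(executor, shape, &result));
  return Status::OK();
}

}  // namespace xla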


@ -18,6 +18,7 @@ limitations under the License.
#include <vector>

+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h"
#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
@ -41,7 +42,7 @@ class CpuTransferManager : public GenericTransferManager {
                                    const LiteralSlice& literal) override;
  Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
                                    const Shape& literal_shape,
-                                   Literal* literal) override;
+                                   MutableBorrowingLiteral literal) override;

 private:
  Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size,


@ -30,47 +30,6 @@ limitations under the License.
namespace xla {
namespace cpu {
StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const {
switch (op->opcode()) {
case HloOpcode::kTanh: {
PrimitiveType element_type = op->shape().element_type();
bool cast_result_to_fp16 = false;
string function_name;
switch (element_type) {
case F16:
cast_result_to_fp16 = true;
operand_value = b_->CreateFPCast(operand_value, b_->getFloatTy());
TF_FALLTHROUGH_INTENDED;
case F32:
function_name = "tanhf";
break;
case F64:
function_name = "tanh";
break;
default:
return Unimplemented("tanh");
}
// Create a function declaration.
llvm::Function* function =
llvm::cast<llvm::Function>(module_->getOrInsertFunction(
llvm_ir::AsStringRef(function_name), operand_value->getType(),
operand_value->getType()));
function->setCallingConv(llvm::CallingConv::C);
function->setDoesNotThrow();
function->setDoesNotAccessMemory();
// Create an instruction to call the function.
llvm::Value* result = b_->CreateCall(function, operand_value);
if (cast_result_to_fp16) {
result = b_->CreateFPCast(result, b_->getHalfTy());
}
return result;
}
default:
return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value);
}
}
StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
    PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const {
  string function_name;
@ -106,6 +65,39 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
  return result;
}
StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(
PrimitiveType prim_type, llvm::Value* value) const {
bool cast_result_to_fp16 = false;
string function_name;
switch (prim_type) {
case F16:
cast_result_to_fp16 = true;
value = b_->CreateFPCast(value, b_->getFloatTy());
TF_FALLTHROUGH_INTENDED;
case F32:
function_name = "tanhf";
break;
case F64:
function_name = "tanh";
break;
default:
return Unimplemented("tanh");
}
// Create a function declaration.
llvm::Function* function = llvm::cast<llvm::Function>(
module_->getOrInsertFunction(llvm_ir::AsStringRef(function_name),
value->getType(), value->getType()));
function->setCallingConv(llvm::CallingConv::C);
function->setDoesNotThrow();
function->setDoesNotAccessMemory();
// Create an instruction to call the function.
llvm::Value* result = b_->CreateCall(function, value);
if (cast_result_to_fp16) {
result = b_->CreateFPCast(result, b_->getHalfTy());
}
return result;
}
llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
    const HloInstruction* hlo,
    const HloToElementGeneratorMap& operand_to_generator) const {

View File

@ -39,10 +39,10 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
      const HloToElementGeneratorMap& operand_to_generator) const override;

 protected:
-  StatusOr<llvm::Value*> EmitFloatUnaryOp(
-      const HloInstruction* op, llvm::Value* operand_value) const override;
  StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
                                   llvm::Value* rhs) const override;
+  StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
+                                  llvm::Value* value) const override;

  IrEmitter* ir_emitter_;
};

View File

@ -1756,6 +1756,10 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
}

Status IrEmitter::HandleReduce(HloInstruction* reduce) {
+  // TODO(b/112040122): Support variadic reduce.
+  if (!ShapeUtil::IsArray(reduce->shape())) {
+    return Unimplemented("Variadic reduce is not supported on CPU");
+  }
  auto arg = reduce->mutable_operand(0);
  auto init_value = reduce->mutable_operand(1);
  gtl::ArraySlice<int64> dimensions(reduce->dimensions());


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

-#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
+#if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY)

#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
#include "third_party/intel_mkl_ml/include/mkl_cblas.h"
#include "third_party/intel_mkl_ml/include/mkl_service.h"


@ -106,6 +106,7 @@ class DfsHloVisitorBase {
  virtual Status HandleConvolution(HloInstructionPtr hlo) = 0;
  virtual Status HandleFft(HloInstructionPtr fft) = 0;
  virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0;
+  virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0;
  virtual Status HandleCompare(HloInstructionPtr hlo) {
    return HandleElementwiseBinary(hlo);
  }


@ -94,6 +94,9 @@ class DfsHloVisitorWithDefaultBase
  Status HandleCrossReplicaSum(HloInstructionPtr crs) override {
    return DefaultAction(crs);
  }
+  Status HandleAllToAll(HloInstructionPtr crs) override {
+    return DefaultAction(crs);
+  }
  Status HandleRng(HloInstructionPtr random) override {
    return DefaultAction(random);
  }


@ -431,6 +431,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
      return EmitCos(op->shape().element_type(), operand_value);
    case HloOpcode::kSin:
      return EmitSin(op->shape().element_type(), operand_value);
+   case HloOpcode::kTanh:
+     return EmitTanh(op->shape().element_type(), operand_value);
    case HloOpcode::kFloor:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor,
                                          {operand_value},
@ -1060,6 +1062,11 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
  return Unimplemented("atan2");
}

+StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
+                                                    llvm::Value* value) const {
+  return Unimplemented("tanh");
+}
+
StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
    const HloInstruction* hlo, llvm::Value* x) const {
  if (hlo->operand(0)->shape().element_type() != F32) {
@ -1239,13 +1246,23 @@ StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution(
  // Convert raw integer to float in range [0, 1) if the element is a float.
  llvm::Value* elem_value = raw_value;
  if (elem_ir_ty->isFloatingPointTy()) {
-   elem_value = b_->CreateUIToFP(elem_value, elem_ir_ty);
    unsigned raw_value_size_in_bits = raw_value_ty->getPrimitiveSizeInBits();
    CHECK(raw_value_size_in_bits == 32 || raw_value_size_in_bits == 64);
-   elem_value = b_->CreateFDiv(
-       elem_value,
-       llvm::ConstantFP::get(elem_ir_ty,
-                             raw_value_size_in_bits == 64 ? 0x1p64 : 0x1p32));
+   // Perform the division using the float type with the same number of bits
+   // as the raw value to avoid overflow.
+   if (raw_value_size_in_bits == 32) {
+     elem_value = b_->CreateUIToFP(elem_value, b_->getFloatTy());
+     elem_value = b_->CreateFDiv(
+         elem_value, llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32)));
+   } else {
+     elem_value = b_->CreateUIToFP(elem_value, b_->getDoubleTy());
+     elem_value = b_->CreateFDiv(
+         elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64)));
+   }
+   if (elem_ir_ty != elem_value->getType()) {
+     elem_value = b_->CreateFPTrunc(elem_value, elem_ir_ty);
+   }
  }

  // Convert the value for the requested distribution.
@ -1302,6 +1319,7 @@ int32 GetNumberOfElementsPerPhiloxRngSample(PrimitiveType elem_prim_ty) {
    case F16:
      return 4;
    case U64:
+   case S64:
    case F64:
      return 2;
    default:
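A standalone C++ illustration of the arithmetic behind the new branch above: the raw bits are divided by 2^32 or 2^64 in a float type wide enough to represent the divisor, and only afterwards narrowed to the element type (dividing directly in a narrow type such as F16 would overflow, since 2^32 is not representable there). Plain C++, not XLA code:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Map raw 32-bit bits to [0, 1) using a float divisor of 2^32.
float UniformFromU32(uint32_t raw) {
  return static_cast<float>(raw) / std::exp2(32.0f);
}

// Map raw 64-bit bits to [0, 1); double is needed so 2^64 is exact and the
// quotient keeps most of the 64 bits of entropy.
double UniformFromU64(uint64_t raw) {
  return static_cast<double>(raw) / std::exp2(64.0);
}

int main() {
  std::printf("%f %f\n", UniformFromU32(0u), UniformFromU32(0x80000000u));
  std::printf("%f %f\n", UniformFromU64(0u), UniformFromU64(1ull << 63));
  return 0;
}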


@ -122,6 +122,9 @@ class ElementalIrEmitter {
                                   llvm::Value* lhs,
                                   llvm::Value* rhs) const;

+  virtual StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
+                                          llvm::Value* value) const;
+
  virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo,
                                                     llvm::Value* x) const;


@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -60,17 +59,19 @@ Status GenericTransferManager::WriteSingleTupleIndexTable(
void GenericTransferManager::TransferLiteralFromDevice(
    se::Stream* stream, const ShapedBuffer& device_buffer,
-   std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) {
+   MutableBorrowingLiteral literal, std::function<void(Status)> done) {
  Status status = stream->BlockHostUntilDone();
  if (!status.ok()) {
    return done(status);
  }
- done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer));
+ done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer,
+                                        literal));
}

-StatusOr<std::unique_ptr<Literal>>
-GenericTransferManager::TransferLiteralFromDeviceInternal(
-   se::StreamExecutor* executor, const ShapedBuffer& device_buffer) {
+Status GenericTransferManager::TransferLiteralFromDeviceInternal(
+   se::StreamExecutor* executor, const ShapedBuffer& device_buffer,
+   MutableBorrowingLiteral literal) {
  VLOG(2) << "transferring literal from device ordinal "
          << executor->device_ordinal() << "; device buffer: " << device_buffer;
  TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal());
@ -80,9 +81,6 @@ GenericTransferManager::TransferLiteralFromDeviceInternal(
  TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(),
                                device_buffer.on_host_shape()));

- std::unique_ptr<Literal> literal =
-     Literal::CreateFromShape(device_buffer.on_host_shape());
-
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
      device_buffer.on_host_shape(),
      [&](const Shape& subshape, const ShapeIndex& index) -> Status {
@ -91,12 +89,12 @@ GenericTransferManager::TransferLiteralFromDeviceInternal(
            /*source=*/device_buffer.buffer(index),
            /*size=*/GetByteSizeRequirement(subshape),
            /*destination=*/
-           literal->untyped_data(index)));
+           literal.untyped_data(index)));
        }
        return Status::OK();
      }));
- return std::move(literal);
+ return Status::OK();
}

Status GenericTransferManager::TransferLiteralToDeviceAsync(
@ -160,7 +158,7 @@ Status GenericTransferManager::TransferLiteralToInfeed(
Status GenericTransferManager::TransferLiteralFromOutfeed(
    se::StreamExecutor* executor, const Shape& literal_shape,
-   Literal* literal) {
+   MutableBorrowingLiteral literal) {
  return Unimplemented("Generic transfer from Outfeed");
}


@ -19,7 +19,6 @@ limitations under the License.
#include <vector>

#include "tensorflow/compiler/xla/service/transfer_manager.h"
-#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@ -41,9 +40,10 @@ class GenericTransferManager : public TransferManager {
  se::Platform::Id PlatformId() const override;

- void TransferLiteralFromDevice(
-     se::Stream* stream, const ShapedBuffer& device_buffer,
-     std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) override;
+ void TransferLiteralFromDevice(se::Stream* stream,
+                                const ShapedBuffer& device_buffer,
+                                MutableBorrowingLiteral literal,
+                                std::function<void(Status)> done) override;

  Status TransferLiteralToDeviceAsync(
      se::Stream* stream, const LiteralSlice& literal,
@ -53,7 +53,7 @@ class GenericTransferManager : public TransferManager {
      const LiteralSlice& literal) override;
  Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
                                    const Shape& literal_shape,
-                                   Literal* literal) override;
+                                   MutableBorrowingLiteral literal) override;

  Status ResetDevices(
      tensorflow::gtl::ArraySlice<se::StreamExecutor*> executors) override;
@ -67,8 +67,9 @@ class GenericTransferManager : public TransferManager {
      const Shape& shape, se::DeviceMemoryBase* region) override;

 private:
- StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDeviceInternal(
-     se::StreamExecutor* executor, const ShapedBuffer& device_buffer);
+ Status TransferLiteralFromDeviceInternal(se::StreamExecutor* executor,
+                                          const ShapedBuffer& device_buffer,
+                                          MutableBorrowingLiteral literal);

  // The platform this transfer manager targets.
  const se::Platform::Id platform_id_;


@ -153,7 +153,6 @@ cc_library(
":ir_emission_utils", ":ir_emission_utils",
":parallel_loop_emitter", ":parallel_loop_emitter",
":partition_assignment", ":partition_assignment",
":while_transformer",
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:status_macros",
@ -166,6 +165,7 @@ cc_library(
"//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/compiler/xla/service:while_loop_analysis",
"//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
"//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
@ -655,7 +655,6 @@ cc_library(
"//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:tuple_simplifier",
"//tensorflow/compiler/xla/service:while_loop_constant_sinking", "//tensorflow/compiler/xla/service:while_loop_constant_sinking",
"//tensorflow/compiler/xla/service:while_loop_invariant_code_motion",
"//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service:while_loop_simplifier",
"//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination",
"//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter", "//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter",
@ -788,32 +787,17 @@ tf_cc_test(
], ],
) )
cc_library(
name = "while_transformer",
srcs = ["while_transformer.cc"],
hdrs = ["while_transformer.h"],
deps = [
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
],
)
tf_cc_test( tf_cc_test(
name = "while_transformer_test", name = "while_transformer_test",
srcs = ["while_transformer_test.cc"], srcs = ["while_transformer_test.cc"],
deps = [ deps = [
":instruction_fusion", ":instruction_fusion",
":while_transformer",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/service:copy_insertion", "//tensorflow/compiler/xla/service:copy_insertion",
"//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:while_loop_analysis",
"//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test", "//tensorflow/core:test",


@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/mutex.h"

namespace xla {
namespace gpu {
@ -137,6 +138,28 @@ string NumBytesToString(int64 bytes) {
      tensorflow::strings::HumanReadableNumBytes(bytes), " (", bytes, "B)");
}
// Acquires a process-global lock on the device pointed to by the given
// StreamExecutor.
//
// This is used to prevent other XLA instances from trying to autotune on this
// device while we're using it.
tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
// se::Platform*s are global singletons guaranteed to live forever.
static auto* mutexes =
new std::map<std::pair<const se::Platform*, /*device_ordinal*/ int64>,
tensorflow::mutex>();
tensorflow::mutex_lock global_lock(mu);
auto it = mutexes
->emplace(std::piecewise_construct,
std::make_tuple(stream_exec->platform(),
stream_exec->device_ordinal()),
std::make_tuple())
.first;
return tensorflow::mutex_lock{it->second};
}
} // anonymous namespace } // anonymous namespace
// We could have caching here so that we don't redo this work for two identical // We could have caching here so that we don't redo this work for two identical
@ -155,6 +178,13 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window, const Shape& output_shape, const Window& window,
const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) { const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) {
// Don't run this function concurrently on the same GPU.
//
// This is a bit of a hack and doesn't protect us against arbitrary concurrent
// use of a GPU, but it's sufficient to let us compile two HLO modules
// concurrently and then run them sequentially.
tensorflow::mutex_lock lock = LockGpu(stream_exec_);
// Create a stream for us to do our work on. // Create a stream for us to do our work on.
se::Stream stream{stream_exec_}; se::Stream stream{stream_exec_};
stream.Init(); stream.Init();
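
For readers skimming the diff, the locking added above reduces to a small reusable pattern: a process-lifetime map from device identity to a mutex, itself guarded by one global mutex, with the per-device lock handed back to the caller for the duration of the autotuning work. A minimal standalone sketch of that pattern using only the C++ standard library (the names `DeviceId` and `LockDevice` are illustrative, not part of the XLA or StreamExecutor APIs):

```cpp
#include <map>
#include <mutex>
#include <utility>

// Hypothetical device identity: a platform pointer plus device ordinal,
// mirroring the (se::Platform*, device_ordinal) key used in the diff above.
using DeviceId = std::pair<const void*, int>;

// Returns a lock that serializes callers targeting the same device.
// The map is intentionally leaked so it outlives all static destructors,
// and std::map never invalidates references to existing entries.
std::unique_lock<std::mutex> LockDevice(const DeviceId& id) {
  static std::mutex map_mu;
  static auto* mutexes = new std::map<DeviceId, std::mutex>();
  std::mutex* device_mu = nullptr;
  {
    std::lock_guard<std::mutex> g(map_mu);  // protect map lookups/inserts
    device_mu = &(*mutexes)[id];            // default-constructs on first use
  }
  return std::unique_lock<std::mutex>(*device_mu);  // held until caller drops it
}
```

A caller would wrap the autotuning region in `auto lock = LockDevice({platform, ordinal});`, which is essentially what `PickBestAlgorithm` does with `LockGpu(stream_exec_)` above.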

View File

@@ -272,27 +272,18 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(
                               prim_type);
 }
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatUnaryOp(
-    const HloInstruction* op, llvm::Value* operand_value) const {
-  PrimitiveType input_type = op->operand(0)->shape().element_type();
-  PrimitiveType output_type = op->shape().element_type();
-  switch (op->opcode()) {
-    case HloOpcode::kTanh:
-      // If we don't care much about precision, emit a fast approximation of
-      // tanh.
-      if (hlo_module_config_.debug_options().xla_enable_fast_math()) {
-        // Upcast F16 to F32 if necessary.
-        llvm::Type* type =
-            input_type == F16 ? b_->getFloatTy() : operand_value->getType();
-        llvm::Value* input = b_->CreateFPCast(operand_value, type);
-        llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
-        return b_->CreateFPCast(fast_tanh, operand_value->getType());
-      }
-      return EmitLibdeviceMathCall("__nv_tanh", {operand_value}, {input_type},
-                                   output_type);
-    default:
-      return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value);
-  }
-}
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(
+    PrimitiveType prim_type, llvm::Value* value) const {
+  // If we don't care much about precision, emit a fast approximation of
+  // tanh.
+  if (hlo_module_config_.debug_options().xla_enable_fast_math()) {
+    // Upcast F16 to F32 if necessary.
+    llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
+    llvm::Value* input = b_->CreateFPCast(value, type);
+    llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
+    return b_->CreateFPCast(fast_tanh, value->getType());
+  }
+  return EmitLibdeviceMathCall("__nv_tanh", {value}, {prim_type}, prim_type);
+}
 llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
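
The fast-math branch above follows a simple recipe: widen F16 inputs to F32, evaluate a cheap tanh approximation in F32, then cast the result back to the original type; without fast math, fall back to libdevice's `__nv_tanh`. A rough scalar analogue in plain C++, where the clamped rational function is a generic stand-in rather than the exact polynomial emitted by `llvm_ir::EmitFastTanh`:

```cpp
#include <algorithm>

// Generic stand-in for a fast tanh approximation (NOT the polynomial that
// llvm_ir::EmitFastTanh emits): clamp the argument where tanh has already
// saturated, then evaluate a small odd rational function. Worst-case absolute
// error of this particular stand-in is roughly 0.02.
float FastTanhApprox(float x) {
  // tanh(3) is within ~0.005 of 1, and the rational below evaluates to
  // exactly +/-1 at the clamp boundary.
  x = std::min(std::max(x, -3.0f), 3.0f);
  const float x2 = x * x;
  // Pade-style approximation: x * (27 + x^2) / (27 + 9 * x^2).
  return x * (27.0f + x2) / (27.0f + 9.0f * x2);
}

// Mirrors the emitter's F16 handling: widen to F32, approximate, narrow back.
// `Half` is assumed to be some half-precision type convertible to/from float.
template <typename Half>
Half FastTanhHalf(Half value) {
  const float widened = static_cast<float>(value);    // "upcast F16 to F32"
  return static_cast<Half>(FastTanhApprox(widened));  // cast back to F16
}
```

Doing the arithmetic in F32 and narrowing once at the end keeps the approximation's intermediate precision reasonable for half-precision inputs while the fast path stays cheap.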
@@ -445,6 +436,8 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
         return b_->CreateLoad(accum_ptr);
       };
     case HloOpcode::kReduce:
+      // TODO(b/112040122): This should be supported.
+      CHECK_EQ(hlo->operand_count(), 2) << "Did not expect variadic reduce";
       return [=, &operand_to_generator](
                  const IrArray::Index& output_index) -> StatusOr<llvm::Value*> {
         const HloInstruction* operand = hlo->operand(0);

View File

@@ -51,9 +51,6 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
       const HloToElementGeneratorMap& operand_to_generator) const override;
  protected:
-  StatusOr<llvm::Value*> EmitFloatUnaryOp(
-      const HloInstruction* op, llvm::Value* operand_value) const override;
   StatusOr<llvm::Value*> EmitFloatBinaryOp(
       const HloInstruction* op, llvm::Value* lhs_value,
       llvm::Value* rhs_value) const override;
@@ -85,6 +82,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
   StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
                                    llvm::Value* rhs) const override;
+  StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
+                                  llvm::Value* value) const override;
   llvm::Value* EmitThreadId() const override;
  private:

View File

@@ -52,12 +52,12 @@ class GemmThunk : public Thunk {
                          se::Stream* stream,
                          HloExecutionProfiler* profiler) override;
-  // Returns true if we'll perform autotuning if run on the given stream. If
-  // so, we want the GPU to be quiescent during autotuning, so as not to
-  // introduce noise in our results.
-  bool ShouldHaltAllActivityBeforeRunning(se::Stream* stream) override {
-    return autotune_results_.count(
-               stream->parent()->GetDeviceDescription().name()) != 0;
+  bool WillAutotuneKernel(se::Stream* stream) override {
+    // We will autotune this kernel if we don't already have an autotune
+    // result for the stream device.
+    return autotune_results_.find(
+               stream->parent()->GetDeviceDescription().name()) ==
+           autotune_results_.end();
   }
  private:
@@ -75,6 +75,8 @@ class GemmThunk : public Thunk {
   // results. The map's value is the best algorithm we've found for this thunk
   // on this device, or an error if none of the algorithms worked and we should
   // use the regular gemm without an algorithm.
+  //
+  // TODO(b/112415150): Make this thread safe.
   std::unordered_map<string, StatusOr<se::blas::AlgorithmType>>
       autotune_results_;
 };
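
Stepping back, `GemmThunk`'s autotuning state is a per-device memo: the first execution on a given device measures candidate algorithms and records either the winner or an error, and every later execution on that device reuses the cached entry. A simplified, self-contained sketch of the same caching discipline (the `AlgorithmOrError` struct and `AutotuneCache` class are illustrative stand-ins for `StatusOr<se::blas::AlgorithmType>` and the thunk's member map):

```cpp
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>

// Either a chosen algorithm id or an error message, standing in for
// StatusOr<se::blas::AlgorithmType> in the real thunk.
struct AlgorithmOrError {
  bool ok = false;
  int64_t algorithm = -1;
  std::string error;
};

class AutotuneCache {
 public:
  // True if we have not yet autotuned on this device; mirrors the idea behind
  // GemmThunk::WillAutotuneKernel.
  bool WillAutotune(const std::string& device_name) const {
    return results_.find(device_name) == results_.end();
  }

  // Runs `autotune` at most once per device and returns the cached result on
  // later calls.
  const AlgorithmOrError& GetOrAutotune(
      const std::string& device_name,
      const std::function<AlgorithmOrError()>& autotune) {
    auto it = results_.find(device_name);
    if (it == results_.end()) {
      it = results_.emplace(device_name, autotune()).first;
    }
    return it->second;
  }

 private:
  std::unordered_map<std::string, AlgorithmOrError> results_;
};
```

As the TODO added above notes, neither the real map nor this sketch is thread safe; concurrent executions of the same thunk would need external synchronization.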

View File

@@ -131,9 +131,10 @@ Status GpuExecutable::ExecuteThunks(
       stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get());
     }
-    // If this thunk requests it, wait for all currently-executing thunks to
-    // finish. This is useful e.g. if the thunk is about to perform autotuning.
-    if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) {
+    // If this thunk is about to autotune, then wait for all currently
+    // executing thunks to finish. This reduces noise and thus the probability
+    // of choosing a suboptimal algorithm.
+    if (thunk->WillAutotuneKernel(stream)) {
       TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone());
     }

Some files were not shown because too many files have changed in this diff.