Merge remote-tracking branch 'upstream/master'

This commit is contained in:
Avijit 2018-08-12 16:21:41 -07:00
commit 9523a98466
2065 changed files with 79800 additions and 17381 deletions

View File

@ -22,6 +22,8 @@ organization for the purposes of conducting machine learning and deep neural
networks research. The system is general enough to be applicable in a wide
variety of other domains, as well.
TensorFlow provides stable Python API and C APIs as well as without API backwards compatibility guarantee like C++, Go, Java, JavaScript and Swift.
Keep up to date with release announcements and security updates by
subscribing to
[announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce).
@ -81,13 +83,13 @@ The TensorFlow project strives to abide by generally accepted best practices in
| Build Type | Status | Artifacts |
| --- | --- | --- |
| **Linux CPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) |
| **Linux GPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.png) | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
| **Linux XLA** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.png) | TBA |
| **MacOS** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) |
| **Windows CPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.png) | [pypi](https://pypi.org/project/tf-nightly/) |
| **Windows GPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.png) | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
| **Android** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.png) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) |
| **Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [pypi](https://pypi.org/project/tf-nightly/) |
| **Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
| **Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA |
| **MacOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [pypi](https://pypi.org/project/tf-nightly/) |
| **Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [pypi](https://pypi.org/project/tf-nightly/) |
| **Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
| **Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) |
### Community Supported Builds
@ -97,17 +99,20 @@ The TensorFlow project strives to abide by generally accepted best practices in
| **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA |
| **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA |
| **IBM ppc64le GPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA |
| **Linux CPU with Intel® MKL-DNN®** | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | TBA |
| **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) |
| **Linux CPU with Intel® MKL-DNN** Python 2.7<br> **Linux CPU with Intel® MKL-DNN** Python 3.5<br> **Linux CPU with Intel® MKL-DNN** Python 3.6| ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)|[1.9.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp27-cp27mu-linux_x86_64.whl)<br>[1.9.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp35-cp35m-linux_x86_64.whl)<br>[1.9.0 py3.6](https://storage.cloud.google.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp36-cp36m-linux_x86_64.whl) |
## For more information
* [Tensorflow Blog](https://medium.com/tensorflow)
* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si)
* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
* [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730)
* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
* [Tensorflow Twitter](https://twitter.com/tensorflow)
* [TensorFlow Website](https://www.tensorflow.org)
* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
* [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730)
* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si)
Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate.

View File

@ -19,7 +19,7 @@
* `tf.data`:
* `tf.contrib.data.group_by_reducer()` is now available via the public API.
* `tf.contrib.data.choose_from_datasets()` is now available via the public API.
* Adding `drop_remainder` argument to `tf.data.Dataset.batch()` and `tf.data.Dataset.padded_batch()`, deprecating tf.contrib.data.batch_and_drop_remainder()` and `tf.contrib.data.padded_batch_and_drop_remainder()`.
* Adding `drop_remainder` argument to `tf.data.Dataset.batch()` and `tf.data.Dataset.padded_batch()`, deprecating `tf.contrib.data.batch_and_drop_remainder()` and `tf.contrib.data.padded_batch_and_drop_remainder()`.
* `tf.estimator`:
* `Estimator`s now use custom savers included in `EstimatorSpec` scaffolds for saving SavedModels during export.
* `EstimatorSpec` will now add a default prediction output for export if no `export_output` is provided, eliminating the need to explicitly include a `PredictOutput` object in the `model_fn` for simple use-cases.

View File

@ -451,11 +451,6 @@ filegroup(
),
)
filegroup(
name = "docs_src",
data = glob(["docs_src/**/*.md"]),
)
cc_library(
name = "grpc",
deps = select({
@ -599,6 +594,7 @@ exports_files(
gen_api_init_files(
name = "tensorflow_python_api_gen",
srcs = ["api_template.__init__.py"],
api_version = 1,
root_init_template = "api_template.__init__.py",
)

View File

@ -1619,5 +1619,66 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) {
TF_DeleteFunction(func1);
}
// This test only works when the TF build includes XLA compiler. One way to set
// this up is via bazel build option "--define with_xla_support=true".
//
// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to
// something like TENSORFLOW_CAPI_USE_XLA.
#ifdef TENSORFLOW_EAGER_USE_XLA
TEST_F(CApiFunctionTest, StatelessIf_XLA) {
TF_Function* func;
const std::string funcName = "BranchFunc";
DefineFunction(funcName.c_str(), &func);
TF_GraphCopyFunction(host_graph_, func, nullptr, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Operation* feed = Placeholder(host_graph_, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Operation* true_cond = ScalarConst(true, host_graph_, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_OperationDescription* desc =
TF_NewOperation(host_graph_, "StatelessIf", "IfNode");
TF_AddInput(desc, {true_cond, 0});
TF_Output inputs[] = {{feed, 0}};
TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs));
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_SetAttrType(desc, "Tcond", TF_BOOL);
TF_DataType inputType = TF_INT32;
TF_SetAttrTypeList(desc, "Tin", &inputType, 1);
TF_SetAttrTypeList(desc, "Tout", &inputType, 1);
TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size());
TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size());
TF_SetDevice(desc, "/device:XLA_CPU:0");
auto op = TF_FinishOperation(desc, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
ASSERT_NE(op, nullptr);
// Create a session for this graph.
CSession csession(host_graph_, s_, /*use_XLA*/ true);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
// Run the graph.
csession.SetInputs({{feed, Int32Tensor(17)}});
csession.SetOutputs({op});
csession.Run(s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Tensor* out = csession.output_tensor(0);
ASSERT_TRUE(out != nullptr);
EXPECT_EQ(TF_INT32, TF_TensorType(out));
EXPECT_EQ(0, TF_NumDims(out)); // scalar
ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
int32* output_contents = static_cast<int32*>(TF_TensorData(out));
EXPECT_EQ(-17, *output_contents);
// Clean up
csession.CloseAndDelete(s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_DeleteFunction(func);
}
#endif // TENSORFLOW_EAGER_USE_XLA
} // namespace
} // namespace tensorflow

View File

@ -26,6 +26,10 @@ limitations under the License.
using tensorflow::GraphDef;
using tensorflow::NodeDef;
static void BoolDeallocator(void* data, size_t, void* arg) {
delete[] static_cast<bool*>(data);
}
static void Int32Deallocator(void* data, size_t, void* arg) {
delete[] static_cast<int32_t*>(data);
}
@ -38,6 +42,14 @@ static void FloatDeallocator(void* data, size_t, void* arg) {
delete[] static_cast<float*>(data);
}
TF_Tensor* BoolTensor(bool v) {
const int num_bytes = sizeof(bool);
bool* values = new bool[1];
values[0] = v;
return TF_NewTensor(TF_BOOL, nullptr, 0, values, num_bytes, &BoolDeallocator,
nullptr);
}
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
int64_t num_values = 1;
for (int i = 0; i < num_dims; ++i) {
@ -131,6 +143,12 @@ TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
return op;
}
TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
const char* name) {
unique_tensor_ptr tensor(BoolTensor(v), TF_DeleteTensor);
return Const(tensor.get(), graph, s, name);
}
TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
const char* name) {
unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor);

View File

@ -31,6 +31,8 @@ using ::tensorflow::string;
typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)>
unique_tensor_ptr;
TF_Tensor* BoolTensor(int32_t v);
// Create a tensor with values of type TF_INT8 provided by `values`.
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values);
@ -55,6 +57,9 @@ TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
const char* name = "const");
TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
const char* name = "scalar");
TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
const char* name = "scalar");

View File

@ -110,7 +110,7 @@ tensorflow::Status GetAllRemoteDevices(
tensorflow::Status CreateRemoteContexts(
const std::vector<string>& remote_workers, int64 rendezvous_id,
const tensorflow::ServerDef& server_def,
int keep_alive_secs, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) {
for (int i = 0; i < remote_workers.size(); i++) {
@ -129,6 +129,7 @@ tensorflow::Status CreateRemoteContexts(
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.set_async(async);
request.set_keep_alive_secs(keep_alive_secs);
auto* eager_client = remote_eager_workers->GetClient(remote_worker);
if (eager_client == nullptr) {
return tensorflow::errors::Internal(
@ -151,7 +152,8 @@ tensorflow::Status CreateRemoteContexts(
}
tensorflow::Status UpdateTFE_ContextWithServerDef(
const tensorflow::ServerDef& server_def, TFE_Context* ctx) {
int keep_alive_secs, const tensorflow::ServerDef& server_def,
TFE_Context* ctx) {
// We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
// server object (which currently CHECK-fails) and we miss the error, instead,
// we log the error, and then return to allow the user to see the error
@ -202,8 +204,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
// Initialize remote eager workers.
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
remote_workers, rendezvous_id, server_def, remote_eager_workers.get(),
ctx->context.Async(), &remote_contexts));
remote_workers, rendezvous_id, keep_alive_secs, server_def,
remote_eager_workers.get(), ctx->context.Async(), &remote_contexts));
tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
@ -222,9 +224,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
auto* device_mgr = grpc_server->worker_env()->device_mgr;
ctx->context.InitializeRemote(
std::move(server), std::move(remote_eager_workers),
std::move(remote_device_mgr), remote_contexts, r, device_mgr);
ctx->context.InitializeRemote(std::move(server),
std::move(remote_eager_workers),
std::move(remote_device_mgr), remote_contexts,
r, device_mgr, keep_alive_secs);
return tensorflow::Status::OK();
#undef LOG_AND_RETURN_IF_ERROR
@ -288,6 +291,7 @@ void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); }
// Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
int keep_alive_secs,
const void* proto,
size_t proto_len,
TF_Status* status) {
@ -297,7 +301,8 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
"Invalid tensorflow.ServerDef protocol buffer");
return;
}
status->status = UpdateTFE_ContextWithServerDef(server_def, ctx);
status->status =
UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx);
}
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
@ -719,6 +724,10 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
}
} // namespace
void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); }
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); }
namespace tensorflow {
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
const tensorflow::AttrValue& default_value,

View File

@ -124,6 +124,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*,
// If the following is set, all servers identified by the
// ServerDef must be up when the context is created.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
int keep_alive_secs,
const void* proto,
size_t proto_len,
TF_Status* status);
@ -380,6 +381,16 @@ TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx,
TF_Buffer* buf,
TF_Status* status);
// Some TF ops need a step container to be set to limit the lifetime of some
// resources (mostly TensorArray and Stack, used in while loop gradients in
// graph mode). Calling this on a context tells it to start a step.
TF_CAPI_EXPORT extern void TFE_ContextStartStep(TFE_Context* ctx);
// Ends a step. When there is no active step (that is, every started step has
// been ended) step containers will be cleared. Note: it is not safe to call
// TFE_ContextEndStep while ops which rely on the step container may be running.
TF_CAPI_EXPORT extern void TFE_ContextEndStep(TFE_Context* ctx);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -151,7 +151,7 @@ void TestRemoteExecute(bool async) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
@ -239,7 +239,7 @@ void TestRemoteExecuteSilentCopies(bool async) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
@ -371,7 +371,7 @@ void TestRemoteExecuteChangeServerDef(bool async) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char remote_device_name[] =
@ -397,7 +397,7 @@ void TestRemoteExecuteChangeServerDef(bool async) {
ASSERT_TRUE(s.ok()) << s.error_message();
ASSERT_TRUE(worker_server->Start().ok());
TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Create a new tensor_handle.

View File

@ -379,9 +379,11 @@ tf_cc_test(
srcs = ["gradients/math_grad_test.cc"],
deps = [
":cc_ops",
":client_session",
":grad_op_registry",
":grad_testutil",
":gradient_checker",
":gradients",
":math_grad",
":testutil",
"//tensorflow/core:lib_internal",

View File

@ -120,6 +120,24 @@ Status SplitGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Split", SplitGrad);
Status FillGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
// y = fill(fill_shape, x)
// No gradient returned for the fill_shape argument.
grad_outputs->push_back(NoGradient());
// The gradient for x (which must be a scalar) is just the sum of
// all the gradients from the shape it fills.
// We use ReduceSum to implement this, which needs an argument providing
// the indices of all the dimensions of the incoming gradient.
// grad(x) = reduce_sum(grad(y), [0..rank(grad(y))])
auto all_dims = Range(scope, Const(scope, 0), Rank(scope, grad_inputs[0]),
Const(scope, 1));
grad_outputs->push_back(ReduceSum(scope, grad_inputs[0], all_dims));
return scope.status();
}
REGISTER_GRADIENT_OP("Fill", FillGrad);
Status DiagGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {

View File

@ -108,6 +108,14 @@ TEST_F(ArrayGradTest, SplitGrad) {
RunTest({x}, {x_shape}, y.output, {y_shape, y_shape});
}
TEST_F(ArrayGradTest, FillGrad) {
TensorShape x_shape({});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
TensorShape y_shape({2, 5, 3});
auto y = Fill(scope_, {2, 5, 3}, x);
RunTest(x, x_shape, y, y_shape);
}
TEST_F(ArrayGradTest, DiagGrad) {
TensorShape x_shape({5, 2});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));

View File

@ -441,6 +441,22 @@ Status RealDivGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);
Status UnsafeDivGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
auto x_1 = ConjugateHelper(scope, op.input(0));
auto x_2 = ConjugateHelper(scope, op.input(1));
// y = x_1 / x_2
// dy/dx_1 = 1/x_2
// dy/dx_2 = -x_1/x_2^2
auto gx_1 = UnsafeDiv(scope, grad_inputs[0], x_2);
auto gx_2 =
Mul(scope, grad_inputs[0],
UnsafeDiv(scope, UnsafeDiv(scope, Neg(scope, x_1), x_2), x_2));
return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("UnsafeDiv", UnsafeDivGrad);
Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
@ -1007,6 +1023,26 @@ Status ProdGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Prod", ProdGrad);
Status SegmentSumGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
// The SegmentSum operation sums segments of the Tensor that have the same
// index in the segment_ids parameter.
// i.e z = [2, 3, 4, 5], segment_ids [0, 0, 0, 1]
// will produce [2 + 3 + 4, 5] = [9, 5]
// The gradient that will flow back to the gather operation will look like
// [x1, x2], it will have the same shape as the output of the SegmentSum
// operation. The differentiation step of the SegmentSum operation just
// broadcast the gradient in order to retrieve the z's shape.
// dy/dz = [x1, x1, x1, x2]
grad_outputs->push_back(Gather(scope, grad_inputs[0], op.input(1)));
// stop propagation along segment_ids
grad_outputs->push_back(NoGradient());
return scope.status();
}
REGISTER_GRADIENT_OP("SegmentSum", SegmentSumGrad);
// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,

View File

@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
@ -42,9 +44,11 @@ using ops::Placeholder;
using ops::Pow;
using ops::Prod;
using ops::RealDiv;
using ops::SegmentSum;
using ops::SquaredDifference;
using ops::Sub;
using ops::Sum;
using ops::UnsafeDiv;
// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions
@ -850,6 +854,36 @@ TEST_F(NaryGradTest, RealDiv) {
RunTest({x}, {x_shape}, {y}, {x_shape});
}
TEST_F(NaryGradTest, UnsafeDiv) {
{
TensorShape x_shape({3, 2, 5});
const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
// Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
// division errors in the numeric estimator used by the gradient checker.
const auto y = UnsafeDiv(
scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
RunTest({x}, {x_shape}, {y}, {x_shape});
}
{
// Return 0 gradient (rather than NaN) for division by zero.
const auto x = Placeholder(scope_, DT_FLOAT);
const auto zero = Const<float>(scope_, 0.0);
const auto y = UnsafeDiv(scope_, x, zero);
std::vector<Output> grad_outputs;
TF_EXPECT_OK(AddSymbolicGradients(scope_, {y}, {x}, &grad_outputs));
ClientSession session(scope_);
std::vector<Tensor> grad_result;
TF_EXPECT_OK(
session.Run({{x, {-3.0f, 0.0f, 3.0f}}}, grad_outputs, &grad_result));
EXPECT_EQ(grad_result.size(), 1);
EXPECT_EQ(grad_result[0].NumElements(), 3);
EXPECT_EQ(grad_result[0].flat<float>()(0), 0.0f);
EXPECT_EQ(grad_result[0].flat<float>()(1), 0.0f);
EXPECT_EQ(grad_result[0].flat<float>()(2), 0.0f);
}
}
TEST_F(NaryGradTest, SquaredDifference) {
TensorShape x1_shape({3, 2, 5});
TensorShape x2_shape({2, 5});
@ -898,5 +932,14 @@ TEST_F(NaryGradTest, Prod) {
RunTest({x}, {x_shape}, {y}, {y_shape});
}
TEST_F(NaryGradTest, SegmentSum) {
TensorShape x_shape({3, 4});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
auto y = SegmentSum(scope_, x, {0, 0, 1});
// the sum is always on the first dimension
TensorShape y_shape({2, 4});
RunTest({x}, {x_shape}, {y}, {y_shape});
}
} // namespace
} // namespace tensorflow

View File

@ -170,7 +170,8 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
variables_directory, MetaFilename(kSavedModelVariablesFilename));
if (!Env::Default()->FileExists(variables_index_path).ok()) {
LOG(INFO) << "The specified SavedModel has no variables; no checkpoints "
"were restored.";
"were restored. File does not exist: "
<< variables_index_path;
return Status::OK();
}
const string variables_path =

View File

@ -48,6 +48,7 @@ cc_library(
"//tensorflow/compiler/xla/client:compile_only_client",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service/cpu:buffer_info_util",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework_internal",

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/str_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
@ -36,6 +37,8 @@ namespace tfcompile {
namespace {
using BufferInfo = cpu_function_runtime::BufferInfo;
bool IsAlpha(char c) {
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
}
@ -85,27 +88,36 @@ Status XLATypeToCpp(xla::PrimitiveType type, string* str) {
return Status::OK();
}
// total_buffer_bytes returns the sum of each size in `sizes`, skipping -1
// values. There are `n` entries in `sizes`.
size_t total_buffer_bytes(const intptr_t* sizes, size_t n) {
size_t total = 0;
for (size_t i = 0; i < n; ++i) {
if (sizes[i] != -1) {
total += sizes[i];
}
}
return total;
// Returns the sum of the size of each buffer in `buffer_infos`.
size_t TotalBufferBytes(const std::vector<BufferInfo>& buffer_infos) {
return std::accumulate(buffer_infos.begin(), buffer_infos.end(), size_t{0},
[](size_t size, const BufferInfo& buffer_info) {
return size + buffer_info.size();
});
}
// Fills in arg_sizes with the byte size of each positional arg.
Status ComputeArgSizes(const CompileResult& compile_result,
std::vector<int64>* arg_sizes) {
const xla::ProgramShape& ps = compile_result.program_shape;
for (int i = 0; i < ps.parameters_size(); ++i) {
arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf(
ps.parameters(i), compile_result.pointer_size));
// Returns a vector of BufferInfo instances in `buffer_infos` that are entry
// parameter buffers.
std::vector<BufferInfo> ExtractEntryParamBufferInfos(
const std::vector<BufferInfo>& buffer_infos) {
std::vector<BufferInfo> result;
std::copy_if(buffer_infos.begin(), buffer_infos.end(),
std::back_inserter(result), [](const BufferInfo& buffer_info) {
return buffer_info.is_entry_parameter();
});
return result;
}
return Status::OK();
// Returns a vector of BufferInfo instances in `buffer_infos` that are temp
// buffers.
std::vector<BufferInfo> ExtractTempBufferInfos(
const std::vector<BufferInfo>& buffer_infos) {
std::vector<BufferInfo> result;
std::copy_if(buffer_infos.begin(), buffer_infos.end(),
std::back_inserter(result), [](const BufferInfo& buffer_info) {
return buffer_info.is_temp_buffer();
});
return result;
}
// Add (from,to) rewrite pairs based on the given shape. These rewrite pairs
@ -278,6 +290,25 @@ Status ValidateFeedFetchCppNames(const tf2xla::Config& config) {
return Status::OK();
}
// Returns a list of C++ expressions that, when executed, will construct the
// BufferInfo instances in `buffer_infos`.
std::vector<string> BufferInfosToCppExpression(
const std::vector<BufferInfo>& buffer_infos) {
std::vector<string> buffer_infos_as_strings;
std::transform(buffer_infos.begin(), buffer_infos.end(),
std::back_inserter(buffer_infos_as_strings),
[](const BufferInfo& buffer_info) {
std::pair<uint64, uint64> encoded = buffer_info.Encode();
string encoded_second_as_str =
encoded.second == ~0ULL
? "~0ULL"
: strings::StrCat(encoded.second, "ULL");
return strings::StrCat(
"::tensorflow::cpu_function_runtime::BufferInfo({",
encoded.first, "ULL, ", encoded_second_as_str, "})");
});
return buffer_infos_as_strings;
}
} // namespace
Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
@ -286,29 +317,35 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
TF_RETURN_IF_ERROR(ValidateConfig(config));
TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config));
const int64 result_index = compile_result.aot->result_buffer_index();
const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes();
if (result_index < 0 || result_index >= temp_sizes.size()) {
const std::vector<BufferInfo>& buffer_infos =
compile_result.aot->buffer_infos();
const std::vector<int32> arg_index_table =
::xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos);
std::vector<string> buffer_infos_as_strings =
BufferInfosToCppExpression(buffer_infos);
if (result_index < 0 || result_index >= buffer_infos.size()) {
return errors::InvalidArgument("result index: ", result_index,
" is outside the range of temp sizes: [0,",
temp_sizes.size(), ")");
buffer_infos.size(), ")");
}
// Compute sizes and generate methods.
std::vector<int64> arg_sizes;
TF_RETURN_IF_ERROR(ComputeArgSizes(compile_result, &arg_sizes));
std::vector<BufferInfo> buffer_infos_for_args =
ExtractEntryParamBufferInfos(buffer_infos);
std::vector<BufferInfo> buffer_infos_for_temps =
ExtractTempBufferInfos(buffer_infos);
const xla::ProgramShape& ps = compile_result.program_shape;
string methods_arg, methods_result;
TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
const std::vector<intptr_t> iarg(arg_sizes.begin(), arg_sizes.end());
const std::vector<intptr_t> itemp(temp_sizes.begin(), temp_sizes.end());
const size_t arg_bytes_aligned =
cpu_function_runtime::AlignedBufferBytes(iarg.data(), iarg.size());
const size_t arg_bytes_total = total_buffer_bytes(iarg.data(), iarg.size());
const size_t temp_bytes_aligned =
cpu_function_runtime::AlignedBufferBytes(itemp.data(), itemp.size());
const size_t temp_bytes_total =
total_buffer_bytes(itemp.data(), itemp.size());
const size_t arg_bytes_aligned = cpu_function_runtime::AlignedBufferBytes(
buffer_infos_for_args.data(), buffer_infos_for_args.size(),
/*allocate_entry_params=*/true);
const size_t arg_bytes_total = TotalBufferBytes(buffer_infos_for_args);
const size_t temp_bytes_aligned = cpu_function_runtime::AlignedBufferBytes(
buffer_infos_for_temps.data(), buffer_infos_for_temps.size(),
/*allocate_entry_params=*/true);
const size_t temp_bytes_total = TotalBufferBytes(buffer_infos_for_temps);
// Create rewrite strings for namespace start and end.
string ns_start;
@ -343,8 +380,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
// calling HloProfilePrinter::profile_counters_size.
const string assign_profile_counters_size =
opts.gen_hlo_profile_printer_data
? "data->profile_counters_size = "
"data->hlo_profile_printer_data->profile_counters_size();"
? "data->set_profile_counters_size("
"data->hlo_profile_printer_data()->profile_counters_size());"
: "";
// Use a poor-man's text templating mechanism; first populate the full header
@ -414,9 +451,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
static constexpr size_t kNumArgs = {{ARG_NUM}};
// Byte size of each argument buffer. There are kNumArgs entries.
static const intptr_t* ArgSizes() {
static constexpr intptr_t kArgSizes[kNumArgs] = {{{ARG_SIZES}}};
return kArgSizes;
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
}
// Returns static data used to create an XlaCompiledCpuFunction.
@ -424,16 +460,16 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
static XlaCompiledCpuFunction::StaticData* kStaticData = [](){
XlaCompiledCpuFunction::StaticData* data =
new XlaCompiledCpuFunction::StaticData;
data->raw_function = {{ENTRY}};
data->arg_sizes = ArgSizes();
data->num_args = kNumArgs;
data->temp_sizes = TempSizes();
data->num_temps = kNumTemps;
data->result_index = kResultIndex;
data->arg_names = StaticArgNames();
data->result_names = StaticResultNames();
data->program_shape = StaticProgramShape();
data->hlo_profile_printer_data = StaticHloProfilePrinterData();
data->set_raw_function({{ENTRY}});
data->set_buffer_infos(BufferInfos());
data->set_num_buffers(kNumBuffers);
data->set_arg_index_table(ArgIndexToBufferIndex());
data->set_num_args(kNumArgs);
data->set_result_index(kResultIndex);
data->set_arg_names(StaticArgNames());
data->set_result_names(StaticResultNames());
data->set_program_shape(StaticProgramShape());
data->set_hlo_profile_printer_data(StaticHloProfilePrinterData());
{{ASSIGN_PROFILE_COUNTERS_SIZE}}
return data;
}();
@ -482,17 +518,27 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{{METHODS_RESULT}}
private:
// Number of result and temporary buffers for the compiled computation.
static constexpr size_t kNumTemps = {{TEMP_NUM}};
// Number of buffers for the compiled computation.
static constexpr size_t kNumBuffers = {{NUM_BUFFERS}};
static const ::tensorflow::cpu_function_runtime::BufferInfo* BufferInfos() {
static const ::tensorflow::cpu_function_runtime::BufferInfo
kBufferInfos[kNumBuffers] = {
{{BUFFER_INFOS_AS_STRING}}
};
return kBufferInfos;
}
static const ::tensorflow::int32* ArgIndexToBufferIndex() {
static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
{{ARG_INDEX_TABLE}}
};
return kArgIndexToBufferIndex;
}
// The 0-based index of the result tuple in the temporary buffers.
static constexpr size_t kResultIndex = {{RESULT_INDEX}};
// Byte size of each result / temporary buffer. There are kNumTemps entries.
static const intptr_t* TempSizes() {
static constexpr intptr_t kTempSizes[kNumTemps] = {{{TEMP_SIZES}}};
return kTempSizes;
}
// Array of names of each positional argument, terminated by nullptr.
static const char** StaticArgNames() {{ARG_NAMES_CODE}}
@ -523,8 +569,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)},
{"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)},
{"{{ARG_NAMES_CODE}}", arg_names_code},
{"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())},
{"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")},
{"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())},
{"{{ARG_INDEX_TABLE}}", str_util::Join(arg_index_table, ", ")},
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
{"{{CLASS}}", opts.class_name},
{"{{DECLS_FROM_OBJ_FILE}}",
@ -546,8 +592,9 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{RESULT_NAMES_CODE}}", result_names_code},
{"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)},
{"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
{"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())},
{"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}};
{"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())},
{"{{BUFFER_INFOS_AS_STRING}}",
str_util::Join(buffer_infos_as_strings, ",\n")}};
str_util::ReplaceAllPairs(header, rewrites);
return Status::OK();
}

View File

@ -32,6 +32,8 @@ namespace tensorflow {
namespace tfcompile {
namespace {
using ::tensorflow::cpu_function_runtime::BufferInfo;
void ExpectErrorContains(const Status& status, StringPiece str) {
EXPECT_NE(Status::OK(), status);
EXPECT_TRUE(str_util::StrContains(status.error_message(), str))
@ -171,8 +173,14 @@ TEST(CodegenTest, Golden) {
fetch->mutable_id()->set_node_name("fetch0");
fetch->set_name("myfetch");
CompileResult compile_result;
compile_result.aot.reset(
new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5, {}));
compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
{},
{BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0),
BufferInfo::MakeTempBuffer(2),
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
5, {}));
compile_result.program_shape = xla::ShapeUtil::MakeProgramShape(
{
xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),

View File

@ -65,9 +65,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
static constexpr size_t kNumArgs = 2;
// Byte size of each argument buffer. There are kNumArgs entries.
static const intptr_t* ArgSizes() {
static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96};
return kArgSizes;
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
}
// Returns static data used to create an XlaCompiledCpuFunction.
@ -75,16 +74,16 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
static XlaCompiledCpuFunction::StaticData* kStaticData = [](){
XlaCompiledCpuFunction::StaticData* data =
new XlaCompiledCpuFunction::StaticData;
data->raw_function = entry_point;
data->arg_sizes = ArgSizes();
data->num_args = kNumArgs;
data->temp_sizes = TempSizes();
data->num_temps = kNumTemps;
data->result_index = kResultIndex;
data->arg_names = StaticArgNames();
data->result_names = StaticResultNames();
data->program_shape = StaticProgramShape();
data->hlo_profile_printer_data = StaticHloProfilePrinterData();
data->set_raw_function(entry_point);
data->set_buffer_infos(BufferInfos());
data->set_num_buffers(kNumBuffers);
data->set_arg_index_table(ArgIndexToBufferIndex());
data->set_num_args(kNumArgs);
data->set_result_index(kResultIndex);
data->set_arg_names(StaticArgNames());
data->set_result_names(StaticResultNames());
data->set_program_shape(StaticProgramShape());
data->set_hlo_profile_printer_data(StaticHloProfilePrinterData());
return data;
}();
@ -215,17 +214,32 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
}
private:
// Number of result and temporary buffers for the compiled computation.
static constexpr size_t kNumTemps = 6;
// Number of buffers for the compiled computation.
static constexpr size_t kNumBuffers = 6;
static const ::tensorflow::cpu_function_runtime::BufferInfo* BufferInfos() {
static const ::tensorflow::cpu_function_runtime::BufferInfo
kBufferInfos[kNumBuffers] = {
::tensorflow::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::tensorflow::cpu_function_runtime::BufferInfo({34ULL, 0ULL}),
::tensorflow::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}),
::tensorflow::cpu_function_runtime::BufferInfo({386ULL, 1ULL}),
::tensorflow::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}),
::tensorflow::cpu_function_runtime::BufferInfo({481ULL, ~0ULL})
};
return kBufferInfos;
}
static const ::tensorflow::int32* ArgIndexToBufferIndex() {
static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
1, 3
};
return kArgIndexToBufferIndex;
}
// The 0-based index of the result tuple in the temporary buffers.
static constexpr size_t kResultIndex = 5;
// Byte size of each result / temporary buffer. There are kNumTemps entries.
static const intptr_t* TempSizes() {
static constexpr intptr_t kTempSizes[kNumTemps] = {1, -1, 2, -1, 3, 120};
return kTempSizes;
}
// Array of names of each positional argument, terminated by nullptr.
static const char** StaticArgNames() {
static const char* kNames[] = {"myfeed", nullptr};

View File

@ -51,11 +51,9 @@ namespace tensorflow {
namespace tfcompile {
namespace {
void zero_buffers(void** bufs, const intptr_t* sizes, size_t n) {
for (int i = 0; i < n; ++i) {
if (sizes[i] != -1) {
memset(bufs[i], 0, sizes[i]);
}
void zero_buffers(XlaCompiledCpuFunction* computation) {
for (int i = 0; i < computation->num_args(); ++i) {
memset(computation->arg_data(i), 0, computation->arg_size(i));
}
}
@ -66,7 +64,7 @@ TEST(TEST_NAME, NoCrash) {
CPP_CLASS computation;
computation.set_thread_pool(&device);
zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs);
zero_buffers(&computation);
EXPECT_TRUE(computation.Run());
}
@ -80,7 +78,7 @@ void BM_NAME(int iters) {
CPP_CLASS computation;
computation.set_thread_pool(&device);
zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs);
zero_buffers(&computation);
testing::StartTiming();
while (--iters) {

View File

@ -44,8 +44,8 @@ using ::testing::IsSupersetOf;
TEST(TFCompileTest, Add) {
AddComp add;
EXPECT_EQ(add.arg0_data(), add.args()[0]);
EXPECT_EQ(add.arg1_data(), add.args()[1]);
EXPECT_EQ(add.arg0_data(), add.arg_data(0));
EXPECT_EQ(add.arg1_data(), add.arg_data(1));
add.arg0() = 1;
add.arg1() = 2;
@ -67,10 +67,10 @@ TEST(TFCompileTest, Add) {
EXPECT_EQ(add_const.error_msg(), "");
EXPECT_EQ(add_const.arg0(), 123);
EXPECT_EQ(add_const.arg0_data()[0], 123);
EXPECT_EQ(add_const.arg0_data(), add.args()[0]);
EXPECT_EQ(add_const.arg0_data(), add.arg_data(0));
EXPECT_EQ(add_const.arg1(), 456);
EXPECT_EQ(add_const.arg1_data()[0], 456);
EXPECT_EQ(add_const.arg1_data(), add.args()[1]);
EXPECT_EQ(add_const.arg1_data(), add.arg_data(1));
EXPECT_EQ(add_const.result0(), 579);
EXPECT_EQ(add_const.result0_data()[0], 579);
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
@ -85,8 +85,8 @@ TEST(TFCompileTest, Add_SetArg) {
int32 arg_y = 32;
add.set_arg0_data(&arg_x);
add.set_arg1_data(&arg_y);
EXPECT_EQ(add.arg0_data(), add.args()[0]);
EXPECT_EQ(add.arg1_data(), add.args()[1]);
EXPECT_EQ(add.arg0_data(), add.arg_data(0));
EXPECT_EQ(add.arg1_data(), add.arg_data(1));
EXPECT_TRUE(add.Run());
EXPECT_EQ(add.error_msg(), "");
@ -97,7 +97,7 @@ TEST(TFCompileTest, Add_SetArg) {
TEST(TFCompileTest, AddWithCkpt) {
AddWithCkptComp add;
EXPECT_EQ(add.arg0_data(), add.args()[0]);
EXPECT_EQ(add.arg0_data(), add.arg_data(0));
add.arg0() = 1;
EXPECT_TRUE(add.Run());
@ -117,7 +117,7 @@ TEST(TFCompileTest, AddWithCkpt) {
EXPECT_EQ(add_const.error_msg(), "");
EXPECT_EQ(add_const.arg0(), 111);
EXPECT_EQ(add_const.arg0_data()[0], 111);
EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]);
EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0));
EXPECT_EQ(add_const.result0(), 153);
EXPECT_EQ(add_const.result0_data()[0], 153);
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
@ -125,7 +125,7 @@ TEST(TFCompileTest, AddWithCkpt) {
TEST(TFCompileTest, AddWithCkptSaver) {
AddWithCkptSaverComp add;
EXPECT_EQ(add.arg0_data(), add.args()[0]);
EXPECT_EQ(add.arg0_data(), add.arg_data(0));
add.arg0() = 1;
EXPECT_TRUE(add.Run());
@ -145,7 +145,7 @@ TEST(TFCompileTest, AddWithCkptSaver) {
EXPECT_EQ(add_const.error_msg(), "");
EXPECT_EQ(add_const.arg0(), 111);
EXPECT_EQ(add_const.arg0_data()[0], 111);
EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]);
EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0));
EXPECT_EQ(add_const.result0(), 153);
EXPECT_EQ(add_const.result0_data()[0], 153);
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
@ -153,9 +153,9 @@ TEST(TFCompileTest, AddWithCkptSaver) {
TEST(TFCompileTest, Cond) {
CondComp cond;
EXPECT_EQ(cond.arg0_data(), cond.args()[0]);
EXPECT_EQ(cond.arg1_data(), cond.args()[1]);
EXPECT_EQ(cond.arg2_data(), cond.args()[2]);
EXPECT_EQ(cond.arg0_data(), cond.arg_data(0));
EXPECT_EQ(cond.arg1_data(), cond.arg_data(1));
EXPECT_EQ(cond.arg2_data(), cond.arg_data(2));
cond.arg1() = 10;
cond.arg2() = 20;
{
@ -178,8 +178,8 @@ TEST(TFCompileTest, Cond) {
TEST(TFCompileTest, Gather) {
GatherComp gather;
EXPECT_EQ(gather.arg0_data(), gather.args()[0]);
EXPECT_EQ(gather.arg1_data(), gather.args()[1]);
EXPECT_EQ(gather.arg0_data(), gather.arg_data(0));
EXPECT_EQ(gather.arg1_data(), gather.arg_data(1));
// Successful gather.
{
@ -202,12 +202,12 @@ TEST(TFCompileTest, Gather) {
EXPECT_EQ(gather_const.arg0(i), params[i]);
EXPECT_EQ(gather_const.arg0_data()[i], params[i]);
}
EXPECT_EQ(gather_const.arg0_data(), gather_const.args()[0]);
EXPECT_EQ(gather_const.arg0_data(), gather_const.arg_data(0));
for (int i = 0; i < 2; ++i) {
EXPECT_EQ(gather_const.arg1(i), indices[i]);
EXPECT_EQ(gather_const.arg1_data()[i], indices[i]);
}
EXPECT_EQ(gather_const.arg1_data(), gather_const.args()[1]);
EXPECT_EQ(gather_const.arg1_data(), gather_const.arg_data(1));
for (int i = 0; i < 2; ++i) {
EXPECT_EQ(gather_const.result0(i), results[i]);
EXPECT_EQ(gather_const.result0_data()[i], results[i]);
@ -222,8 +222,8 @@ TEST(TFCompileTest, MatMul2) {
foo::bar::MatMulComp matmul;
matmul.set_thread_pool(&device);
EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]);
EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]);
EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0));
EXPECT_EQ(matmul.arg1_data(), matmul.arg_data(1));
// Test using the argN() methods.
{
@ -271,12 +271,12 @@ TEST(TFCompileTest, MatMul2) {
EXPECT_EQ(matmul_const.arg0(i / 3, i % 3), args[i]);
EXPECT_EQ(matmul_const.arg0_data()[i], args[i]);
}
EXPECT_EQ(matmul_const.arg0_data(), matmul.args()[0]);
EXPECT_EQ(matmul_const.arg0_data(), matmul.arg_data(0));
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(matmul_const.arg1(i / 2, i % 2), args[i + 6]);
EXPECT_EQ(matmul_const.arg1_data()[i], args[i + 6]);
}
EXPECT_EQ(matmul_const.arg1_data(), matmul.args()[1]);
EXPECT_EQ(matmul_const.arg1_data(), matmul.arg_data(1));
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(matmul_const.result0(i / 2, i % 2), results[i]);
EXPECT_EQ(matmul_const.result0_data()[i], results[i]);
@ -300,8 +300,8 @@ TEST(TFCompileTest, MatMul2_SetArg) {
float arg1[3][2] = {{7, 8}, {9, 10}, {11, 12}};
matmul.set_arg0_data(&arg0);
matmul.set_arg1_data(&arg1);
EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]);
EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]);
EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0));
EXPECT_EQ(matmul.arg1_data(), matmul.arg_data(1));
EXPECT_TRUE(matmul.Run());
EXPECT_EQ(matmul.error_msg(), "");
@ -319,8 +319,8 @@ TEST(TFCompileTest, MatMulAndAdd1) {
MatMulAndAddComp muladd;
muladd.set_thread_pool(&device);
EXPECT_EQ(muladd.arg0_data(), muladd.args()[0]);
EXPECT_EQ(muladd.arg1_data(), muladd.args()[1]);
EXPECT_EQ(muladd.arg0_data(), muladd.arg_data(0));
EXPECT_EQ(muladd.arg1_data(), muladd.arg_data(1));
// Test methods with positional args and results.
{
@ -346,12 +346,12 @@ TEST(TFCompileTest, MatMulAndAdd1) {
EXPECT_EQ(muladd_const.arg0(i / 2, i % 2), args[i]);
EXPECT_EQ(muladd_const.arg0_data()[i], args[i]);
}
EXPECT_EQ(muladd_const.arg0_data(), muladd.args()[0]);
EXPECT_EQ(muladd_const.arg0_data(), muladd.arg_data(0));
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.arg1(i / 2, i % 2), args[i + 4]);
EXPECT_EQ(muladd_const.arg1_data()[i], args[i + 4]);
}
EXPECT_EQ(muladd_const.arg1_data(), muladd.args()[1]);
EXPECT_EQ(muladd_const.arg1_data(), muladd.arg_data(1));
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.result0(i / 2, i % 2), results0[i]);
EXPECT_EQ(muladd_const.result0_data()[i], results0[i]);
@ -387,12 +387,12 @@ TEST(TFCompileTest, MatMulAndAdd1) {
EXPECT_EQ(muladd_const.arg_x(i / 2, i % 2), args[i]);
EXPECT_EQ(muladd_const.arg_x_data()[i], args[i]);
}
EXPECT_EQ(muladd_const.arg_x_data(), muladd.args()[0]);
EXPECT_EQ(muladd_const.arg_x_data(), muladd.arg_data(0));
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.arg_y(i / 2, i % 2), args[i + 4]);
EXPECT_EQ(muladd_const.arg_y_data()[i], args[i + 4]);
}
EXPECT_EQ(muladd_const.arg_y_data(), muladd.args()[1]);
EXPECT_EQ(muladd_const.arg_y_data(), muladd.arg_data(1));
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.result_x_y_prod(i / 2, i % 2), results0[i]);
EXPECT_EQ(muladd_const.result_x_y_prod_data()[i], results0[i]);
@ -407,8 +407,8 @@ TEST(TFCompileTest, MatMulAndAdd1) {
TEST(TFCompileTest, Function) {
// The function is equivalent to an addition
FunctionComp add_fn;
EXPECT_EQ(add_fn.arg0_data(), add_fn.args()[0]);
EXPECT_EQ(add_fn.arg1_data(), add_fn.args()[1]);
EXPECT_EQ(add_fn.arg0_data(), add_fn.arg_data(0));
EXPECT_EQ(add_fn.arg1_data(), add_fn.arg_data(1));
add_fn.arg0() = 1;
add_fn.arg1() = 2;
@ -451,8 +451,8 @@ TEST(TFCompileTest, AssertEqAndReturnDiff) {
// Assert is converted into a no-op in XLA, so there is no failure even if the
// two args are different.
AssertComp assert;
EXPECT_EQ(assert.arg0_data(), assert.args()[0]);
EXPECT_EQ(assert.arg1_data(), assert.args()[1]);
EXPECT_EQ(assert.arg0_data(), assert.arg_data(0));
EXPECT_EQ(assert.arg1_data(), assert.arg_data(1));
assert.arg0() = 2;
assert.arg1() = 1;

View File

@ -160,6 +160,7 @@ cc_library(
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:util",
@ -178,6 +179,7 @@ cc_library(
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:fifo_queue",
"//tensorflow/core/kernels:function_ops",
"//tensorflow/core/kernels:identity_n_op",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:no_op",
@ -186,6 +188,9 @@ cc_library(
"//tensorflow/core/kernels:sendrecv_ops",
"//tensorflow/core/kernels:shape_ops",
"//tensorflow/core/kernels:variable_ops",
"//tensorflow/core/kernels/data:generator_dataset_op",
"//tensorflow/core/kernels/data:iterator_ops",
"//tensorflow/core/kernels/data:prefetch_dataset_op",
],
)

View File

@ -46,6 +46,7 @@ class Predicate {
virtual string ToString() const = 0;
int64 hash() const { return hash_; }
virtual gtl::ArraySlice<Predicate*> GetOperands() const = 0;
virtual Kind kind() const = 0;
virtual ~Predicate() {}
@ -90,7 +91,8 @@ class AndPredicate : public Predicate {
Kind kind() const override { return Kind::kAnd; }
const gtl::ArraySlice<Predicate*> operands() const { return operands_; }
gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
gtl::ArraySlice<Predicate*> operands() const { return operands_; }
private:
std::vector<Predicate*> operands_;
@ -117,7 +119,8 @@ class OrPredicate : public Predicate {
}
Kind kind() const override { return Kind::kOr; }
const gtl::ArraySlice<Predicate*> operands() const { return operands_; }
gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
gtl::ArraySlice<Predicate*> operands() const { return operands_; }
private:
std::vector<Predicate*> operands_;
@ -128,17 +131,18 @@ class NotPredicate : public Predicate {
public:
explicit NotPredicate(Predicate* operand)
: Predicate(HashPredicateSequence(Kind::kNot, {operand})),
operand_(operand) {}
operands_({operand}) {}
string ToString() const override {
return strings::StrCat("~", operand()->ToString());
}
Kind kind() const override { return Kind::kNot; }
Predicate* operand() const { return operand_; }
Predicate* operand() const { return operands_[0]; }
gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
private:
Predicate* operand_;
std::array<Predicate*, 1> operands_;
};
// Represents an uninterpreted symbol in a logical predicate.
@ -158,6 +162,7 @@ class SymbolPredicate : public Predicate {
}
Kind kind() const override { return Kind::kSymbol; }
gtl::ArraySlice<Predicate*> GetOperands() const override { return {}; }
// If `must_be_true()` is true this SymbolPredicate represents the proposition
// "tensor_id() is live and evaluates to true".
@ -288,10 +293,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands,
if (op->kind() == pred_kind) {
// "Inline" the operands of an inner And/Or into the parent And/Or.
gtl::ArraySlice<Predicate*> operands =
is_and ? dynamic_cast<AndPredicate*>(op)->operands()
: dynamic_cast<OrPredicate*>(op)->operands();
for (Predicate* subop : operands) {
for (Predicate* subop : op->GetOperands()) {
if (simplified_ops_set.insert(subop).second) {
simplified_ops.push_back(subop);
}

View File

@ -1161,8 +1161,7 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef(
strings::StrCat("replace_encapsulate_fdef_", name), fdef);
}
TF_RETURN_IF_ERROR(library->RemoveFunction(name));
TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
return Status::OK();
}

View File

@ -16,6 +16,7 @@ cc_library(
"//tensorflow/compiler/jit:xla_device",
"//tensorflow/compiler/jit:xla_launch_util",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
@ -199,7 +200,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
run_options.set_stream(stream);
run_options.set_allocator(xla_allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(ctx->step_id());
run_options.set_rng_seed(GetXLARandomSeed());
Env* env = Env::Default();
auto start_time = env->NowMicros();

View File

@ -296,7 +296,7 @@ Status XlaCompilationCache::CompileImpl(
// protect the contents of the cache entry.
Entry* entry;
{
mutex_lock lock(mu_);
mutex_lock lock(compile_cache_mu_);
// Find or create a cache entry.
std::unique_ptr<Entry>& e = cache_[signature];
if (!e) {
@ -312,6 +312,8 @@ Status XlaCompilationCache::CompileImpl(
if (!entry->compiled) {
VLOG(1) << "Compilation cache miss for signature: "
<< SignatureDebugString(signature);
tensorflow::Env* env = tensorflow::Env::Default();
const uint64 compile_start_us = env->NowMicros();
// Do the actual JIT compilation without holding the lock (it can take
// a long time.)
std::vector<XlaCompiler::Argument> args;
@ -334,6 +336,26 @@ Status XlaCompilationCache::CompileImpl(
CHECK_EQ(entry->executable.get(), nullptr);
entry->compilation_status =
BuildExecutable(options, entry->compilation_result, &entry->executable);
const uint64 compile_end_us = env->NowMicros();
const uint64 compile_time_us = compile_end_us - compile_start_us;
{
mutex_lock lock(compile_stats_mu_);
auto it = compile_stats_.emplace(function.name(), CompileStats{}).first;
it->second.compile_count++;
it->second.cumulative_compile_time_us += compile_time_us;
VLOG(1) << "compiled " << function.name() << " "
<< it->second.compile_count
<< " times, compile time: " << compile_time_us
<< " us, cumulative: " << it->second.cumulative_compile_time_us
<< " us ("
<< tensorflow::strings::HumanReadableElapsedTime(compile_time_us /
1.0e6)
<< " / "
<< tensorflow::strings::HumanReadableElapsedTime(
it->second.cumulative_compile_time_us / 1.0e6)
<< ")";
}
}
TF_RETURN_IF_ERROR(entry->compilation_status);
*compilation_result = &entry->compilation_result;

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
@ -150,9 +151,22 @@ class XlaCompilationCache : public ResourceBase {
std::unique_ptr<xla::LocalExecutable> executable GUARDED_BY(mu);
};
mutex mu_;
std::unordered_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
GUARDED_BY(mu_);
mutex compile_cache_mu_;
gtl::FlatMap<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
GUARDED_BY(compile_cache_mu_);
struct CompileStats {
// Number of times the cluster has been (re-)compiled.
int64 compile_count = 0;
// Cumulative time spent compiling the cluster.
int64 cumulative_compile_time_us = 0;
};
mutex compile_stats_mu_;
// Maps cluster names to compilation statistics for said cluster.
gtl::FlatMap<string, CompileStats> compile_stats_
GUARDED_BY(compile_stats_mu_);
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
};

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@ -71,7 +72,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
run_options.set_stream(stream);
run_options.set_allocator(client->backend().memory_allocator());
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(ctx->step_id());
run_options.set_rng_seed(GetXLARandomSeed());
xla::StatusOr<xla::ScopedShapedBuffer> run_result =
executable->Run(launch_context.arguments(), run_options);

View File

@ -211,17 +211,18 @@ XlaDevice::XlaDevice(
use_multiple_streams),
device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name),
xla_allocator_(nullptr),
platform_(platform),
use_multiple_streams_(use_multiple_streams),
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(shape_representation_fn) {
VLOG(1) << "Created XLA device " << jit_device_name;
VLOG(1) << "Created XLA device " << jit_device_name << " " << this;
}
XlaDevice::~XlaDevice() {
if (gpu_device_info_ != nullptr) {
gpu_device_info_->default_context->Unref();
VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this;
mutex_lock lock(mu_);
if (device_context_) {
device_context_->Unref();
}
}
@ -237,6 +238,11 @@ xla::LocalClient* XlaDevice::client() const {
}
Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
mutex_lock lock(mu_);
return GetAllocatorLocked(attr);
}
Allocator* XlaDevice::GetAllocatorLocked(AllocatorAttributes attr) {
if (attr.on_host()) {
return cpu_allocator();
}
@ -249,83 +255,105 @@ Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
return xla_allocator_;
}
xla::StatusOr<se::Stream*> XlaDevice::GetStream() {
if (!stream_) {
xla::Backend* backend = client()->mutable_backend();
TF_ASSIGN_OR_RETURN(stream_, backend->BorrowStream(device_ordinal_));
}
return stream_.get();
Status XlaDevice::EnsureDeviceContextOk() {
mutex_lock lock(mu_);
return GetDeviceContextLocked().status();
}
xla::StatusOr<se::Stream*> XlaDevice::GetDeviceToHostStream() {
if (!use_multiple_streams_) {
return GetStream();
Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend,
const string& name,
xla::StreamPool::Ptr* stream,
bool* stream_was_changed) {
if (!(*stream) || !(*stream)->ok()) {
TF_ASSIGN_OR_RETURN(*stream, backend->BorrowStream(device_ordinal_));
VLOG(1) << "XlaDevice " << this << " new " << name << " "
<< (*stream)->DebugStreamPointers();
*stream_was_changed = true;
}
if (!device_to_host_stream_) {
xla::Backend* backend = client()->mutable_backend();
TF_ASSIGN_OR_RETURN(device_to_host_stream_,
backend->BorrowStream(device_ordinal_));
}
return device_to_host_stream_.get();
}
xla::StatusOr<se::Stream*> XlaDevice::GetHostToDeviceStream() {
if (!use_multiple_streams_) {
return GetStream();
}
if (!host_to_device_stream_) {
xla::Backend* backend = client()->mutable_backend();
TF_ASSIGN_OR_RETURN(host_to_device_stream_,
backend->BorrowStream(device_ordinal_));
}
return host_to_device_stream_.get();
}
Status XlaDevice::CreateAndSetGpuDeviceInfo() {
if (gpu_device_info_ == nullptr) {
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
// Call GetAllocator for the side-effect of ensuring the allocator
// is created.
GetAllocator({});
// XlaDevice owns both gpu_device_info_ and
// gpu_device_info_->default_context.
gpu_device_info_ = MakeUnique<GpuDeviceInfo>();
gpu_device_info_->stream = stream;
gpu_device_info_->default_context =
new XlaDeviceContext(stream, stream, stream, client(),
transfer_as_literal_, shape_representation_fn_);
set_tensorflow_gpu_device_info(gpu_device_info_.get());
}
return Status::OK();
}
xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
xla::Backend* backend = client()->mutable_backend();
// Ensure all our streams are valid, borrowing new streams if necessary.
bool need_new_device_context = !device_context_;
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
&need_new_device_context));
se::Stream* host_to_device_stream = stream_.get();
se::Stream* device_to_host_stream = stream_.get();
if (use_multiple_streams_) {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream",
&host_to_device_stream_,
&need_new_device_context));
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream",
&device_to_host_stream_,
&need_new_device_context));
host_to_device_stream = host_to_device_stream_.get();
device_to_host_stream = device_to_host_stream_.get();
}
if (!need_new_device_context) {
return device_context_;
}
// At this point we know we need a new device context.
// Call GetAllocator for the side-effect of ensuring the allocator is created.
GetAllocatorLocked({});
if (device_context_) {
device_context_->Unref();
}
device_context_ = new XlaDeviceContext(
stream_.get(), host_to_device_stream, device_to_host_stream, client(),
transfer_as_literal_, shape_representation_fn_);
VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext "
<< device_context_;
// Create and set a new GpuDeviceInfo, if necessary.
//
// TODO(b/78232898): This isn't thread-safe; there is a race between the call
// to set_tensorflow_gpu_device_info() with ops that call the getter
// tensorflow_gpu_device_info(). This isn't trivially fixed by adding locking
// to those methods; see the bug for details. Our only saving grace at the
// moment is that this race doesn't seem to occur in practice.
if (use_gpu_device_info_) {
auto gpu_device_info = MakeUnique<GpuDeviceInfo>();
gpu_device_info->stream = stream_.get();
gpu_device_info->default_context = device_context_;
set_tensorflow_gpu_device_info(gpu_device_info.get());
gpu_device_info_ = std::move(gpu_device_info);
VLOG(1) << "XlaDevice " << this << " new GpuDeviceInfo "
<< gpu_device_info_.get();
}
return device_context_;
}
Status XlaDevice::UseGpuDeviceInfo() {
mutex_lock lock(mu_);
use_gpu_device_info_ = true;
return GetDeviceContextLocked().status();
}
Status XlaDevice::FillContextMap(const Graph* graph,
DeviceContextMap* device_context_map) {
VLOG(1) << "XlaDevice::FillContextMap";
device_context_map->resize(graph->num_node_ids());
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
GetDeviceToHostStream());
TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
GetHostToDeviceStream());
mutex_lock lock(mu_);
TF_ASSIGN_OR_RETURN(XlaDeviceContext * device_context,
GetDeviceContextLocked());
// Call GetAllocator for the side-effect of ensuring the allocator is created.
GetAllocator({});
auto ctx = new XlaDeviceContext(
stream, host_to_device_stream, device_to_host_stream, client(),
transfer_as_literal_, shape_representation_fn_);
device_context_map->resize(graph->num_node_ids());
for (Node* n : graph->nodes()) {
VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
ctx->Ref();
(*device_context_map)[n->id()] = ctx;
device_context->Ref();
(*device_context_map)[n->id()] = device_context;
}
ctx->Unref();
return Status::OK();
}
void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
VLOG(1) << "XlaDevice::Compute " << op_kernel->name() << ":"
VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
<< op_kernel->type_string();
// When Xprof profiling is off (which is the default), constructing the
// activity is simple enough that its overhead is negligible.
@ -336,7 +364,7 @@ void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) {
VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
<< op_kernel->type_string();
tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(),
op_kernel->IsExpensive());
@ -358,17 +386,13 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
if (alloc_attrs.on_host()) {
*tensor = parsed;
} else {
Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
mutex_lock lock(mu_);
TF_ASSIGN_OR_RETURN(XlaDeviceContext * device_context,
GetDeviceContextLocked());
Allocator* allocator = GetAllocatorLocked(alloc_attrs);
Tensor copy(allocator, parsed.dtype(), parsed.shape());
Notification n;
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
GetDeviceToHostStream());
TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
GetHostToDeviceStream());
XlaTransferManager manager(stream, host_to_device_stream,
device_to_host_stream, client(),
transfer_as_literal_, shape_representation_fn_);
manager.CopyCPUTensorToDevice(&parsed, this, &copy,
device_context->CopyCPUTensorToDevice(&parsed, this, &copy,
[&n, &status](const Status& s) {
status = s;
n.Notify();

View File

@ -25,6 +25,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@ -40,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace tensorflow {
@ -117,62 +119,85 @@ class XlaDevice : public LocalDevice {
const PaddedShapeFn& padded_shape_fn);
~XlaDevice() override;
Allocator* GetAllocator(AllocatorAttributes attr) override;
Allocator* GetAllocator(AllocatorAttributes attr) override
LOCKS_EXCLUDED(mu_);
void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) override;
Status Sync() override { return Status::OK(); }
Status FillContextMap(const Graph* graph,
DeviceContextMap* device_context_map) override;
DeviceContextMap* device_context_map) override
LOCKS_EXCLUDED(mu_);
Status MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) override;
Tensor* tensor) override LOCKS_EXCLUDED(mu_);
xla::LocalClient* client() const;
const Metadata& metadata() { return xla_metadata_; }
xla::StatusOr<se::Stream*> GetStream();
xla::StatusOr<se::Stream*> GetHostToDeviceStream();
xla::StatusOr<se::Stream*> GetDeviceToHostStream();
// If not already set, create and set GpuDeviceInfo.
// Not thread-safe
Status CreateAndSetGpuDeviceInfo();
// Ensures the DeviceContext associated with this XlaDevice is created and
// valid (i.e. all streams are ok). If any state is not valid, a new
// DeviceContext will be created.
//
// TODO(b/111859745): The Eager context needs to call this method to recover
// from failures.
Status EnsureDeviceContextOk() LOCKS_EXCLUDED(mu_);
// Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra
// information for GPU and TPU devices.
Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_);
private:
xla::LocalClient* client() const;
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
xla::StreamPool::Ptr* stream,
bool* stream_was_changed)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
xla::StatusOr<XlaDeviceContext*> GetDeviceContextLocked()
EXCLUSIVE_LOCKS_REQUIRED(mu_);
mutex mu_;
// The metadata of this XlaDevice.
const Metadata xla_metadata_;
// Which hardware device in the client's platform this XlaDevice controls.
const int device_ordinal_;
// The name of the device that is used to compile Ops for this XlaDevice.
DeviceType jit_device_name_;
const DeviceType jit_device_name_;
// The platform for this device.
se::Platform* const platform_; // Not owned.
// Memory allocator associated with this device.
Allocator* xla_allocator_; // Not owned.
se::Platform* platform_; // Not owned.
Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned.
// Stream associated with this device. Operations enqueued on this
// stream are executed on the device. Operations include data
// copying back and forth between CPU and the device, and
// computations enqueued by XLA.
xla::StreamPool::Ptr stream_;
// If true, only stream_ is valid and all computation and transfers use
// stream_. If false, computation is performed by stream_ and transfers are
xla::StreamPool::Ptr stream_ GUARDED_BY(mu_);
// If false, only stream_ is valid and all computation and transfers use
// stream_. If true, computation is performed by stream_ and transfers are
// performed by host_to_device/device_to_host_stream.
bool use_multiple_streams_;
const bool use_multiple_streams_;
// If use_multiple_streams_, host to device transfers are performed using this
// stream.
xla::StreamPool::Ptr host_to_device_stream_;
xla::StreamPool::Ptr host_to_device_stream_ GUARDED_BY(mu_);
// If use_multiple_streams_, device to host transfers are performed using this
// stream.
xla::StreamPool::Ptr device_to_host_stream_;
xla::StreamPool::Ptr device_to_host_stream_ GUARDED_BY(mu_);
// Must we use XLA's transfer manager for correct host<->device transfers? if
// false, we can use ThenMemcpy() instead.
bool transfer_as_literal_;
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
const bool transfer_as_literal_;
const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
// If set, holds default device context (that we must Unref)
// and its stream.
std::unique_ptr<GpuDeviceInfo> gpu_device_info_;
// The device context accessed by all users of the XlaDevice, set by calls to
// EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer is
// also filled in to that struct. XlaDeviceContext is a ref-counted object.
XlaDeviceContext* device_context_ GUARDED_BY(mu_) = nullptr;
// Holds extra information for GPU and TPU devices, e.g. the device context.
bool use_gpu_device_info_ GUARDED_BY(mu_) = false;
std::unique_ptr<GpuDeviceInfo> gpu_device_info_ GUARDED_BY(mu_);
};
// Builds OpKernel registrations on 'device' for the JIT operators

View File

@ -101,34 +101,27 @@ Status XlaTransferManager::TransferLiteralToDevice(
// Unref the host tensor, and capture the literal shared_ptr too so it goes
// out of scope when the lambda completes.
host_to_device_stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); });
return Status::OK();
}
void XlaTransferManager::TransferLiteralFromDevice(
Tensor* host_tensor, const Tensor& device_tensor,
const StatusCallback& done) const {
xla::MutableBorrowingLiteral literal;
TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(host_tensor, &literal));
const xla::ShapedBuffer& shaped_buffer =
XlaTensor::FromTensor(&device_tensor)->shaped_buffer();
TensorReference ref(device_tensor);
transfer_manager_->TransferLiteralFromDevice(
device_to_host_stream_, shaped_buffer,
[=, &shaped_buffer](
xla::StatusOr<std::unique_ptr<xla::Literal> > literal_or) {
device_to_host_stream_, shaped_buffer, literal,
[=, &shaped_buffer, &literal](xla::Status status) {
ref.Unref();
done([&]() -> Status {
TF_ASSIGN_OR_RETURN(auto literal, std::move(literal_or));
VLOG(1) << "Transfer from device as literal: " << literal->ToString()
VLOG(1) << "Transfer from device as literal: " << literal.ToString()
<< " " << shaped_buffer.ToString();
Tensor tensor;
TF_RETURN_IF_ERROR(
LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor));
// Reshape the tensor back to its declared shape.
Status status;
if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) {
status = errors::Internal(
"Tensor::CopyFrom failed when copying from XLA device to CPU");
}
return status;
}());
});

View File

@ -23,7 +23,11 @@ limitations under the License.
#include "tensorflow/core/kernels/cast_op.h"
#include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/kernels/control_flow_ops.h"
#include "tensorflow/core/kernels/data/generator_dataset_op.h"
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include "tensorflow/core/kernels/data/prefetch_dataset_op.h"
#include "tensorflow/core/kernels/fifo_queue.h"
#include "tensorflow/core/kernels/function_ops.h"
#include "tensorflow/core/kernels/identity_n_op.h"
#include "tensorflow/core/kernels/identity_op.h"
#include "tensorflow/core/kernels/no_op.h"
@ -166,7 +170,69 @@ class XlaAssignVariableOp : public AsyncOpKernel {
QueueIsClosedOp); \
\
REGISTER_KERNEL_BUILDER( \
Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp);
Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); \
\
REGISTER_KERNEL_BUILDER( \
Name(kArgOp).Device(DEVICE).HostMemory("output").TypeConstraint("T", \
TYPES), \
ArgOp); \
REGISTER_KERNEL_BUILDER(Name(kArgOp) \
.Device(DEVICE) \
.HostMemory("output") \
.TypeConstraint<ResourceHandle>("T"), \
ArgOp); \
\
REGISTER_KERNEL_BUILDER(Name(kRetOp) \
.Device(DEVICE) \
.TypeConstraint("T", TYPES) \
.HostMemory("input"), \
RetvalOp); \
REGISTER_KERNEL_BUILDER(Name(kRetOp) \
.Device(DEVICE) \
.TypeConstraint<ResourceHandle>("T") \
.HostMemory("input"), \
RetvalOp); \
\
REGISTER_KERNEL_BUILDER( \
Name("RemoteCall").Device(DEVICE).HostMemory("target"), RemoteCallOp); \
\
REGISTER_KERNEL_BUILDER( \
Name("GeneratorDataset").Device(DEVICE).HostMemory("handle"), \
GeneratorDatasetOp); \
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") \
.Device(DEVICE) \
.HostMemory("buffer_size") \
.HostMemory("input_dataset") \
.HostMemory("handle"), \
PrefetchDatasetOp); \
\
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE), \
IteratorHandleOp); \
REGISTER_KERNEL_BUILDER( \
Name("MakeIterator").Device(DEVICE).HostMemory("dataset"), \
MakeIteratorOp); \
REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \
AnonymousIteratorHandleOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \
IteratorGetNextOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \
.Device(DEVICE) \
.HostMemory("string_handle"), \
IteratorToStringHandleOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2") \
.Device(DEVICE) \
.HostMemory("string_handle"), \
IteratorFromStringHandleOp); \
REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp) \
.Device(DEVICE) \
.HostMemory("output") \
.TypeConstraint<string>("T"), \
ArgOp); \
REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kRetOp) \
.Device(DEVICE) \
.TypeConstraint<string>("T") \
.HostMemory("input"), \
RetvalOp);
// TODO(phawkins): currently we do not register the QueueEnqueueMany,
// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read

View File

@ -59,7 +59,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
}
// TODO(b/78468222): Uncomment after fixing this bug
// status = device->CreateAndSetGpuDeviceInfo();
// status = device->UseGpuDeviceInfo();
// if (!status.ok()) {
// errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT,
// " device");

View File

@ -691,11 +691,7 @@ tf_xla_py_test(
size = "small",
srcs = ["random_ops_test.py"],
disabled_backends = [
# TODO(b/110300529): RngNormal doesn't return values with the expected variance
"cpu",
"cpu_ondemand",
# TODO(b/31361304): enable RNG ops on GPU when parallelized.
"gpu",
],
deps = [
":xla_test",

View File

@ -52,6 +52,9 @@ class AdamOptimizerTest(xla_test.XLATestCase):
def testBasic(self):
for dtype in self.float_types:
# TODO: test fails for float16 due to excessive precision requirements.
if dtype == np.float16:
continue
with self.test_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)
@ -91,6 +94,9 @@ class AdamOptimizerTest(xla_test.XLATestCase):
def testTensorLearningRate(self):
for dtype in self.float_types:
# TODO: test fails for float16 due to excessive precision requirements.
if dtype == np.float16:
continue
with self.test_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)
@ -130,6 +136,9 @@ class AdamOptimizerTest(xla_test.XLATestCase):
def testSharing(self):
for dtype in self.float_types:
# TODO: test fails for float16 due to excessive precision requirements.
if dtype == np.float16:
continue
with self.test_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)

View File

@ -32,6 +32,7 @@ from tensorflow.python.layers import convolutional
from tensorflow.python.layers import pooling
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
@ -122,6 +123,14 @@ class EagerTest(xla_test.XLATestCase):
with self.test_scope():
self.assertAllEqual(2, array_ops.identity(2))
def testRandomOps(self):
with self.test_scope():
tensor = gen_random_ops.random_uniform((2, 2), dtypes.float32)
row0 = tensor[0].numpy()
row1 = tensor[1].numpy()
# It should be very unlikely to rng to generate two equal rows.
self.assertFalse((row0 == row1).all())
def testIdentityOnVariable(self):
with self.test_scope():
v = resource_variable_ops.ResourceVariable(True)

View File

@ -57,7 +57,8 @@ class RandomOpsTest(xla_test.XLATestCase):
def testRandomUniformIsNotConstant(self):
def rng(dtype):
return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000)
dtype = dtypes.as_dtype(dtype)
return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=dtype.max)
for dtype in self._random_types():
self._testRngIsNotConstant(rng, dtype)
@ -73,6 +74,11 @@ class RandomOpsTest(xla_test.XLATestCase):
def testRandomUniformIsInRange(self):
for dtype in self._random_types():
# TODO (b/112272078): enable bfloat16 for CPU and GPU when the bug is
# fixed.
if (self.device in ["XLA_GPU", "XLA_CPU"
]) and (dtype in [dtypes.bfloat16, dtypes.half]):
continue
with self.test_session() as sess:
with self.test_scope():
x = random_ops.random_uniform(
@ -95,7 +101,7 @@ class RandomOpsTest(xla_test.XLATestCase):
for dtype in [dtypes.float32]:
with self.test_session() as sess:
with self.test_scope():
x = random_ops.truncated_normal(shape=[count], dtype=dtype, seed=42)
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
y = sess.run(x)
def normal_cdf(x):
@ -124,20 +130,23 @@ class RandomOpsTest(xla_test.XLATestCase):
# Department of Scientific Computing website. Florida State University.
expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
actual_mean = np.mean(y)
self.assertAllClose(actual_mean, expected_mean, atol=2e-4)
self.assertAllClose(actual_mean, expected_mean, atol=2e-3)
expected_median = mu + probit(
(normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma
actual_median = np.median(y)
self.assertAllClose(actual_median, expected_median, atol=8e-4)
self.assertAllClose(actual_median, expected_median, atol=1e-2)
expected_variance = sigma**2 * (1 + (
(alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - (
(normal_pdf(alpha) - normal_pdf(beta)) / z)**2)
actual_variance = np.var(y)
self.assertAllClose(actual_variance, expected_variance, rtol=3e-4)
self.assertAllClose(actual_variance, expected_variance, rtol=2*1e-3)
def testShuffle1d(self):
# TODO(b/26783907): this test requires the CPU backend to implement sort.
if self.device in ["XLA_CPU"]:
return
with self.test_session() as sess:
with self.test_scope():
x = math_ops.range(1 << 16)

View File

@ -361,6 +361,12 @@ class UnaryOpsTest(xla_test.XLATestCase):
np.array([[-0.05, 6.05, 5]], dtype=dtype),
expected=np.array([[0, 6, 5]], dtype=dtype))
self._assertOpOutputMatchesExpected(
nn_ops.softmax,
np.array([1, 2, 3, 4], dtype=dtype),
expected=np.array([0.032058604, 0.087144323, 0.23688284, 0.64391428],
dtype=dtype))
self._assertOpOutputMatchesExpected(
nn_ops.softmax,
np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype),
@ -369,6 +375,14 @@ class UnaryOpsTest(xla_test.XLATestCase):
[0.032058604, 0.087144323, 0.23688284, 0.64391428]],
dtype=dtype))
self._assertOpOutputMatchesExpected(
nn_ops.softmax,
np.array([[[1, 1], [1, 1]], [[1, 2], [3, 4]]], dtype=dtype),
expected=np.array(
[[[0.5, 0.5], [0.5, 0.5]],
[[0.26894142, 0.73105858], [0.26894142, 0.73105858]]],
dtype=dtype))
self._assertOpOutputMatchesExpected(
nn_ops.softsign,
np.array([[-2, -1, 0, 1, 2]], dtype=dtype),

View File

@ -21,6 +21,8 @@ from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_control_flow_ops
@ -47,6 +49,34 @@ class XlaDeviceTest(xla_test.XLATestCase):
result = sess.run(z, {x: inputs})
self.assertAllCloseAccordingToType(result, inputs + inputs)
def testCopiesOfUnsupportedTypesFailGracefully(self):
"""Tests that copies of unsupported types don't crash."""
test_types = set([
np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16, np.int32,
np.int64, np.float16, np.float32, np.float16,
dtypes.bfloat16.as_numpy_dtype
])
shape = (10, 10)
for unsupported_dtype in test_types - self.all_types:
with self.test_session() as sess:
with ops.device("CPU"):
x = array_ops.placeholder(unsupported_dtype, shape)
with self.test_scope():
y, = array_ops.identity_n([x])
with ops.device("CPU"):
z = array_ops.identity(y)
inputs = np.random.randint(-100, 100, shape)
inputs = inputs.astype(unsupported_dtype)
# Execution should either succeed or raise an InvalidArgumentError,
# but not crash. Even "unsupported types" may succeed here since some
# backends (e.g., the CPU backend) are happy to handle buffers of
# unsupported types, even if they cannot compute with them.
try:
sess.run(z, {x: inputs})
except errors.InvalidArgumentError:
pass
def testControlTrigger(self):
with self.test_session() as sess:
with self.test_scope():

View File

@ -95,6 +95,10 @@ cc_library(
name = "cpu_function_runtime",
srcs = ["cpu_function_runtime.cc"],
hdrs = ["cpu_function_runtime.h"],
visibility = [
"//tensorflow/compiler/aot:__pkg__",
"//tensorflow/compiler/xla/service/cpu:__pkg__",
],
deps = [
# Keep dependencies to a minimum here; this library is used in every AOT
# binary produced by tfcompile.
@ -144,6 +148,7 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/compiler/xla/service/cpu:buffer_info_util",
"//tensorflow/compiler/xla/service/cpu:cpu_executable",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

View File

@ -55,19 +55,26 @@ size_t align_to(size_t n, size_t align) {
} // namespace
namespace cpu_function_runtime {
size_t AlignedBufferBytes(const intptr_t* sizes, size_t n) {
size_t AlignedBufferBytes(const BufferInfo* buffer_infos, size_t n,
bool allocate_entry_params) {
size_t total = 0;
for (size_t i = 0; i < n; ++i) {
if (sizes[i] > 0) {
total += align_to(sizes[i], kAlign);
bool should_allocate =
buffer_infos[i].is_temp_buffer() ||
(buffer_infos[i].is_entry_parameter() && allocate_entry_params);
if (should_allocate) {
total += align_to(buffer_infos[i].size(), kAlign);
}
}
return total;
}
void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
void* MallocContiguousBuffers(const BufferInfo* buffer_infos, size_t n,
bool allocate_entry_params, void** bufs,
bool annotate_initialized) {
const size_t total = AlignedBufferBytes(sizes, n);
const size_t total =
AlignedBufferBytes(buffer_infos, n, allocate_entry_params);
void* contiguous = nullptr;
if (total > 0) {
contiguous = aligned_malloc(total, kAlign);
@ -79,13 +86,14 @@ void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
}
uintptr_t pos = reinterpret_cast<uintptr_t>(contiguous);
for (size_t i = 0; i < n; ++i) {
if (sizes[i] < 0) {
// bufs[i] is either a constant, an entry parameter or a thread local
// allocation.
bufs[i] = nullptr;
} else {
bool should_allocate =
buffer_infos[i].is_temp_buffer() ||
(buffer_infos[i].is_entry_parameter() && allocate_entry_params);
if (should_allocate) {
bufs[i] = reinterpret_cast<void*>(pos);
pos += align_to(sizes[i], kAlign);
pos += align_to(buffer_infos[i].size(), kAlign);
} else {
bufs[i] = nullptr;
}
}
return contiguous;

View File

@ -18,29 +18,142 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include <cassert>
namespace tensorflow {
namespace cpu_function_runtime {
// Stores information about one buffer used by an XLA:CPU compiled function.
// These buffers are used for holding inputs to the computation, outputs from
// the computation and as temporary scratch space.
class BufferInfo {
public:
// Creates a BufferInfo from a serialized encoding generated by `Encode`.
explicit BufferInfo(std::pair<uint64, uint64> encoding)
: entry_param_number_(encoding.second) {
Kind kind;
uint64 size;
Unpack(encoding.first, &kind, &size);
kind_ = kind;
size_ = size;
}
// Returns true if this buffer stores a constant. These never need to be
// allocated by the runtime.
bool is_constant() const { return kind() == Kind::kConstant; }
// Returns true if this buffer stores an entry parameter. These may or may
// not need to be allocated by the runtime, depending on
// XlaCompiledCpuFunction::AllocMode.
bool is_entry_parameter() const { return kind() == Kind::kEntryParameter; }
// Returns the entry parameter number of this buffer.
uint64 entry_parameter_number() const {
assert(is_entry_parameter());
return entry_param_number_;
}
// Returns true if this buffer is temporary scratch space required by the XLA
// computations. These are always allocated by the runtime.
bool is_temp_buffer() const { return kind() == Kind::kTempBuffer; }
// Returns true if this buffer is allocated on the C stack or into registers.
// These buffers are never allocated by the runtime.
bool is_on_stack_buffer() const { return kind() == Kind::kOnStackBuffer; }
// Returns the size for this buffer.
uint64 size() const { return size_; }
// Encodes this BufferInfo into two 64 bit integers that can be used to
// reconstruct the BufferInfo later using the constructor. We need this
// because we use BufferInfo in places where using protocol buffers would
// negatively impact binary size.
std::pair<uint64, uint64> Encode() const {
static_assert(sizeof(*this) == 16, "");
uint64 upper = Pack(kind(), size_);
uint64 lower = entry_param_number_;
return {upper, lower};
}
bool operator==(const BufferInfo& buffer_info) const {
if (kind() != buffer_info.kind() || size() != buffer_info.size()) {
return false;
}
return !is_entry_parameter() ||
entry_parameter_number() == buffer_info.entry_parameter_number();
}
// Factory methods:
static BufferInfo MakeTempBuffer(uint64 size) {
return BufferInfo(Kind::kTempBuffer, /*size=*/size,
/*entry_param_number=*/-1);
}
static BufferInfo MakeConstant(uint64 size) {
return BufferInfo(Kind::kConstant, /*size=*/size,
/*entry_param_number=*/-1);
}
static BufferInfo MakeEntryParameter(uint64 size, uint64 param_number) {
return BufferInfo(Kind::kEntryParameter, /*size=*/size,
/*entry_param_number=*/param_number);
}
static BufferInfo MakeOnStackBuffer(uint64 size) {
return BufferInfo(Kind::kOnStackBuffer, /*size=*/size,
/*entry_param_number=*/-1);
}
private:
BufferInfo() = default;
enum class Kind : unsigned {
kConstant,
kTempBuffer,
kEntryParameter,
kOnStackBuffer
};
Kind kind() const { return static_cast<Kind>(kind_); }
explicit BufferInfo(Kind kind, uint64 size, uint64 entry_param_number)
: kind_(kind), size_(size), entry_param_number_(entry_param_number) {}
static uint64 Pack(Kind kind, uint64 size) {
return (static_cast<uint64>(size) << 2) | static_cast<uint64>(kind);
}
static void Unpack(uint64 packed, Kind* kind, uint64* size) {
*size = packed >> 2;
*kind = static_cast<Kind>((packed << 62) >> 62);
}
Kind kind_ : 2;
uint64 size_ : 62;
int64 entry_param_number_;
};
// Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment.
constexpr size_t kAlign = 64;
// AlignedBufferBytes returns the sum of each size in `sizes`, skipping -1
// values. There are `n` entries in `sizes`. Each buffer is aligned to
// kAlign byte boundaries.
size_t AlignedBufferBytes(const intptr_t* sizes, size_t n);
// AlignedBufferBytes returns the sum of the size of each buffer in
// `buffer_infos`, skipping constants, on-stack buffers and, if
// allocate_entry_params is false, entry parameters. There are `n` entries in
// `buffer_infos`. Each buffer is aligned to kAlign byte boundaries.
size_t AlignedBufferBytes(const BufferInfo* buffer_infos, size_t n,
bool allocate_entry_params);
// MallocContiguousBuffers allocates buffers for use by the entry point
// generated by tfcompile. `sizes` is an array of byte sizes for each buffer,
// where -1 causes the buffer pointer to be nullptr. There are `n` entries in
// `sizes`. If `annotate_initialized` is set, the allocated memory will be
// annotated as having been initialized - this is useful when allocating
// temporary buffers.
// generated by tfcompile. There are `n` entries in `buffer_infos`. If
// `annotate_initialized` is set, the allocated memory will be annotated as
// having been initialized - this is useful when allocating temporary buffers.
// If allocate_entry_params is true then allocates temp buffers and entry
// parameters, otherwise allocated only temp buffers. Slots in `bufs`
// corresponding to unallocated buffers are set to nullptr.
//
// A single contiguous block of memory is allocated, and portions of it are
// parceled out into `bufs`, which must have space for `n` entries. Returns
// the head of the allocated contiguous block, which should be passed to
// FreeContiguous when the buffers are no longer in use.
void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
void* MallocContiguousBuffers(const BufferInfo* buffer_infos, size_t n,
bool allocate_entry_params, void** bufs,
bool annotate_initialized);
// FreeContiguous frees the contiguous block of memory allocated by

View File

@ -21,6 +21,8 @@ limitations under the License.
namespace tensorflow {
namespace {
using cpu_function_runtime::BufferInfo;
TEST(XlaCompiledCpuFunctionTest, AlignmentValue) {
// We've chosen 64 byte alignment for the tfcompile runtime to mimic the
// regular tensorflow allocator, which was chosen to play nicely with Eigen.
@ -30,20 +32,51 @@ TEST(XlaCompiledCpuFunctionTest, AlignmentValue) {
EXPECT_EQ(cpu_function_runtime::kAlign, Allocator::kAllocatorAlignment);
}
std::vector<BufferInfo> SizesToBufferInfos(const intptr_t* sizes, size_t n) {
std::vector<BufferInfo> buffer_infos;
std::transform(sizes, sizes + n, std::back_inserter(buffer_infos),
[&](intptr_t size) {
if (size == -1) {
// Use a dummy on-stack buffer allocation to indicat the
// the current slot does not need an allocation.
int64 on_stack_buffer_size = 4;
return BufferInfo::MakeOnStackBuffer(on_stack_buffer_size);
}
return BufferInfo::MakeTempBuffer(size);
});
return buffer_infos;
}
// Simple wrappers to make writing tests more ergonomic.
size_t AlignedBufferBytesFromSizes(const intptr_t* sizes, size_t n) {
std::vector<BufferInfo> buffer_infos = SizesToBufferInfos(sizes, n);
return AlignedBufferBytes(buffer_infos.data(), n,
/*allocate_entry_params=*/false);
}
void* MallocContiguousBuffersFromSizes(const intptr_t* sizes, size_t n,
void** bufs, bool annotate_initialized) {
std::vector<BufferInfo> buffer_infos = SizesToBufferInfos(sizes, n);
return MallocContiguousBuffers(buffer_infos.data(), n,
/*allocate_entry_params=*/false, bufs,
annotate_initialized);
}
TEST(XlaCompiledCpuFunctionTest, AlignedBufferBytes) {
EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(nullptr, 0), 0);
EXPECT_EQ(AlignedBufferBytesFromSizes(nullptr, 0), 0);
static constexpr intptr_t sizesA[1] = {-1};
EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesA, 1), 0);
EXPECT_EQ(AlignedBufferBytesFromSizes(sizesA, 1), 0);
static constexpr intptr_t sizesB[1] = {3};
EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesB, 1), 64);
EXPECT_EQ(AlignedBufferBytesFromSizes(sizesB, 1), 64);
static constexpr intptr_t sizesC[1] = {32};
EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesC, 1), 64);
EXPECT_EQ(AlignedBufferBytesFromSizes(sizesC, 1), 64);
static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3};
EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesD, 7), 320);
EXPECT_EQ(AlignedBufferBytesFromSizes(sizesD, 7), 320);
}
void* add_ptr(void* base, uintptr_t delta) {
@ -56,15 +89,14 @@ void* add_ptr(void* base, uintptr_t delta) {
// free. We also check the contiguous property.
TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
// Test empty sizes.
void* base =
cpu_function_runtime::MallocContiguousBuffers(nullptr, 0, nullptr, false);
void* base = MallocContiguousBuffersFromSizes(nullptr, 0, nullptr, false);
EXPECT_EQ(base, nullptr);
cpu_function_runtime::FreeContiguous(base);
// Test non-empty sizes with 0 sum.
static constexpr intptr_t sizesA[1] = {-1};
void* bufA[1];
base = cpu_function_runtime::MallocContiguousBuffers(sizesA, 1, bufA, false);
base = MallocContiguousBuffersFromSizes(sizesA, 1, bufA, false);
EXPECT_EQ(base, nullptr);
EXPECT_EQ(bufA[0], nullptr);
cpu_function_runtime::FreeContiguous(base);
@ -72,7 +104,7 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
// Test non-empty sizes with non-0 sum.
static constexpr intptr_t sizesB[1] = {3};
void* bufB[1];
base = cpu_function_runtime::MallocContiguousBuffers(sizesB, 1, bufB, false);
base = MallocContiguousBuffersFromSizes(sizesB, 1, bufB, false);
EXPECT_NE(base, nullptr);
EXPECT_EQ(bufB[0], add_ptr(base, 0));
char* bufB0_bytes = static_cast<char*>(bufB[0]);
@ -84,7 +116,7 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
// Test non-empty sizes with non-0 sum, and annotate_initialized.
static constexpr intptr_t sizesC[1] = {3};
void* bufC[1];
base = cpu_function_runtime::MallocContiguousBuffers(sizesC, 1, bufC, true);
base = MallocContiguousBuffersFromSizes(sizesC, 1, bufC, true);
EXPECT_NE(base, nullptr);
EXPECT_EQ(bufC[0], add_ptr(base, 0));
char* bufC0_bytes = static_cast<char*>(bufC[0]);
@ -96,7 +128,7 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
// Test mixed sizes.
static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3};
void* bufD[7];
base = cpu_function_runtime::MallocContiguousBuffers(sizesD, 7, bufD, false);
base = MallocContiguousBuffersFromSizes(sizesD, 7, bufD, false);
EXPECT_NE(base, nullptr);
EXPECT_EQ(bufD[0], add_ptr(base, 0));
EXPECT_EQ(bufD[1], nullptr);
@ -117,5 +149,23 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
cpu_function_runtime::FreeContiguous(base);
}
void CheckRoundTripIsOk(const BufferInfo& buffer_info) {
BufferInfo round_trip(buffer_info.Encode());
ASSERT_EQ(round_trip, buffer_info);
}
TEST(XlaCompiledCpuFunctionTest, BufferInfoTest) {
CheckRoundTripIsOk(BufferInfo::MakeTempBuffer(0));
CheckRoundTripIsOk(BufferInfo::MakeTempBuffer(4));
CheckRoundTripIsOk(BufferInfo::MakeOnStackBuffer(0));
CheckRoundTripIsOk(BufferInfo::MakeOnStackBuffer(4));
CheckRoundTripIsOk(BufferInfo::MakeConstant(0));
CheckRoundTripIsOk(BufferInfo::MakeConstant(4));
CheckRoundTripIsOk(
BufferInfo::MakeEntryParameter(/*size=*/0, /*param_number=*/4));
CheckRoundTripIsOk(
BufferInfo::MakeEntryParameter(/*size=*/4, /*param_number=*/0));
}
} // namespace
} // namespace tensorflow

View File

@ -6,6 +6,10 @@ package(
load("//tensorflow:tensorflow.bzl", "tf_copts")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
)
tf_kernel_library(
name = "xla_ops",
@ -129,6 +133,7 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:numeric",
"//tensorflow/compiler/xla/client/lib:pooling",
"//tensorflow/compiler/xla/client/lib:prng",
"//tensorflow/compiler/xla/client/lib:sorting",
"//tensorflow/core:framework",
@ -153,8 +158,14 @@ tf_kernel_library(
"//tensorflow/core/kernels:sparse_to_dense_op",
"//tensorflow/core/kernels:stack_ops",
"//tensorflow/core/kernels:training_ops",
] + if_mkl(
[
"//tensorflow/core/kernels:mkl_transpose_op",
],
[
"//tensorflow/core/kernels:transpose_op",
],
),
)
tf_kernel_library(

View File

@ -65,6 +65,6 @@ class XlaArgOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(XlaArgOp);
};
REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes(), XlaArgOp);
REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes().CompilationOnly(), XlaArgOp);
} // namespace tensorflow

View File

@ -200,12 +200,23 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
}
}
bool resource_variable_seen = false;
for (int i = 0; i < ctx->num_inputs(); ++i) {
if (ctx->input_type(i) == DT_RESOURCE) {
resource_variable_seen = true;
} else {
OP_REQUIRES(
ctx, !resource_variable_seen,
errors::FailedPrecondition(
"Resource variables and regular inputs cannot be interleaved."));
}
}
xla::XlaOp outputs = xla::Conditional(
ctx->Input(0), xla::Tuple(b, inputs), *then_result.computation,
xla::Tuple(b, inputs), *else_result.computation);
// Sets non-variable outputs.
for (int i = 0; i < output_types_.size(); ++i) {
if (ctx->input_type(i) != DT_RESOURCE) {
xla::XlaOp output_handle = xla::GetTupleElement(outputs, i);
if (VLOG_IS_ON(2)) {
LOG(INFO) << "Setting output " << i;
@ -219,7 +230,6 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
}
ctx->SetOutput(i, output_handle);
}
}
// Updates the values of any resource variables modified by the conditional
// bodies.
@ -247,6 +257,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
}
REGISTER_XLA_OP(Name("If").AllowResourceTypes(), XlaIfOp);
REGISTER_XLA_OP(Name("StatelessIf").AllowResourceTypes(), XlaIfOp);
REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp);
} // namespace tensorflow

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/pooling.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
@ -71,59 +72,53 @@ class PoolingOp : public XlaOpKernel {
int num_dims() const { return num_spatial_dims_ + 2; }
// Method that builds an initial value to use in reductions.
virtual xla::XlaOp InitValue(xla::XlaBuilder* b) = 0;
// The reduction operation to apply to each window.
virtual const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) = 0;
// A post-processing operation to apply on the outputs of the ReduceWindow.
virtual xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx,
const xla::XlaOp& output, DataType dtype,
const TensorShape& input_shape) = 0;
void Compile(XlaOpKernelContext* ctx) override {
std::vector<int64> ksize = ksize_;
std::vector<int64> stride = stride_;
if (ctx->num_inputs() != 1) {
protected:
xla::StatusOr<std::vector<int64>> GetKernelSize(XlaOpKernelContext* ctx) {
if (ctx->num_inputs() == 1) {
return ksize_;
}
const TensorShape ksize_shape = ctx->InputShape(1);
// Validate input sizes.
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
errors::InvalidArgument("ksize must be a vector, not shape ",
ksize_shape.DebugString()));
OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(),
errors::InvalidArgument("Sliding window ksize field must "
if (!TensorShapeUtils::IsVector(ksize_shape)) {
return errors::InvalidArgument("ksize must be a vector, not shape ",
ksize_shape.DebugString());
}
if (ksize_shape.num_elements() != num_dims()) {
return errors::InvalidArgument(
"Sliding window ksize field must "
"specify ",
num_dims(), " dimensions"));
ksize.clear();
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize));
num_dims(), " dimensions");
}
std::vector<int64> ksize;
auto status = ctx->ConstantInputAsIntVector(1, &ksize);
if (!status.ok()) {
return status;
}
return ksize;
}
xla::StatusOr<std::vector<int64>> GetStride(XlaOpKernelContext* ctx) {
if (ctx->num_inputs() == 1) {
return stride_;
}
const TensorShape stride_shape = ctx->InputShape(2);
// Validate input sizes.
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
errors::InvalidArgument("stride must be a vector, not shape ",
stride_shape.DebugString()));
OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(),
errors::InvalidArgument("Sliding window stride field must "
"specify ",
num_dims(), " dimensions"));
stride.clear();
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride));
if (!TensorShapeUtils::IsVector(stride_shape)) {
return errors::InvalidArgument("stride must be a vector, not shape ",
stride_shape.DebugString());
}
const TensorShape input_shape = ctx->InputShape(0);
OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
errors::InvalidArgument("Input to ", type_string(),
" operator must have ", num_dims(),
" dimensions"));
xla::XlaBuilder* const b = ctx->builder();
auto input =
XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_);
auto reduce = xla::ReduceWindow(input, InitValue(b), *Reduction(ctx), ksize,
stride, padding_);
auto pooled = XlaHelpers::ConvertElementType(b, reduce, input_type(0));
ctx->SetOutput(0,
PostProcessOutput(ctx, pooled, input_type(0), input_shape));
if (stride_shape.num_elements() != num_dims()) {
return errors::InvalidArgument(
"Sliding window stride field must "
"specify ",
num_dims(), " dimensions");
}
std::vector<int64> stride;
auto status = ctx->ConstantInputAsIntVector(2, &stride);
if (!status.ok()) {
return status;
}
return stride;
}
protected:
@ -136,24 +131,48 @@ class PoolingOp : public XlaOpKernel {
xla::PrimitiveType xla_reduction_type_;
};
// Converts the tensor data format to the one required by the XLA pooling
// library.
xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format,
int num_spatial_dims) {
int num_dims = num_spatial_dims + 2;
int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format);
int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format);
gtl::InlinedVector<int64, 4> spatial_dimensions(num_spatial_dims);
for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
spatial_dimensions[spatial_dim] =
GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim);
}
return xla::TensorFormat(/*batch_dimension=*/batch_dimension,
/*feature_dimension=*/feature_dimension,
/*spatial_dimensions=*/spatial_dimensions);
}
class MaxPoolOp : public PoolingOp {
public:
MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
: PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
/*reduction_type=*/ctx->input_type(0)) {}
xla::XlaOp InitValue(xla::XlaBuilder* b) override {
return xla::MinValue(b, xla_reduction_type_);
}
void Compile(XlaOpKernelContext* ctx) override {
auto ksize_or_error = GetKernelSize(ctx);
OP_REQUIRES_OK(ctx, ksize_or_error.status());
std::vector<int64> ksize = ksize_or_error.ValueOrDie();
const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override {
return ctx->GetOrCreateMax(reduction_type_);
}
auto stride_or_error = GetStride(ctx);
OP_REQUIRES_OK(ctx, stride_or_error.status());
std::vector<int64> stride = stride_or_error.ValueOrDie();
xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx,
const xla::XlaOp& output, DataType dtype,
const TensorShape& input_shape) override {
return output;
const TensorShape input_shape = ctx->InputShape(0);
OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
errors::InvalidArgument("Input to ", type_string(),
" operator must have ", num_dims(),
" dimensions"));
auto pooling =
xla::MaxPool(ctx->Input(0), ksize, stride, padding_,
XlaTensorFormat(data_format_, input_shape.dims() - 2));
ctx->SetOutput(0, pooling);
}
};
@ -180,9 +199,8 @@ class MaxPool3DOp : public MaxPoolOp {
};
REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);
// Common computation shared between AvgPool and AvgPoolGrad. Divide each
// element of an image by the count of elements that contributed to that
// element during pooling.
// Divide each element of an image by the count of elements that contributed to
// that element during pooling.
static xla::XlaOp AvgPoolDivideByCount(
XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype,
const TensorShape& input_shape, xla::Padding padding,
@ -241,20 +259,34 @@ class AvgPoolOp : public PoolingOp {
/*reduction_type=*/
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
xla::XlaOp InitValue(xla::XlaBuilder* b) override {
return xla::Zero(b, xla_reduction_type_);
}
void Compile(XlaOpKernelContext* ctx) override {
auto ksize_or_error = GetKernelSize(ctx);
OP_REQUIRES_OK(ctx, ksize_or_error.status());
std::vector<int64> ksize = ksize_or_error.ValueOrDie();
const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override {
return ctx->GetOrCreateAdd(reduction_type_);
}
auto stride_or_error = GetStride(ctx);
OP_REQUIRES_OK(ctx, stride_or_error.status());
std::vector<int64> stride = stride_or_error.ValueOrDie();
xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx,
const xla::XlaOp& output, DataType dtype,
const TensorShape& input_shape) override {
return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_,
ksize_, stride_, num_spatial_dims_,
data_format_);
const TensorShape input_shape = ctx->InputShape(0);
OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
errors::InvalidArgument("Input to ", type_string(),
" operator must have ", num_dims(),
" dimensions"));
auto xla_data_format =
XlaTensorFormat(data_format_, input_shape.dims() - 2);
auto spatial_padding = MakeSpatialPadding(
input_shape.dim_sizes(), ksize, stride, padding_, xla_data_format);
// Convert the input to the reduction type.
auto converted_input =
ConvertElementType(ctx->Input(0), xla_reduction_type_);
auto pooling =
xla::AvgPool(converted_input, ksize, stride, spatial_padding,
xla_data_format, padding_ == xla::Padding::kValid);
// Convert the pooling result back to the input type before returning it.
ctx->SetOutput(0, ConvertElementType(pooling, ctx->input_xla_type(0)));
}
};

View File

@ -104,7 +104,7 @@ class RetvalOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp);
};
REGISTER_XLA_OP(Name("_Retval"), RetvalOp);
REGISTER_XLA_OP(Name("_Retval").CompilationOnly(), RetvalOp);
} // anonymous namespace
} // namespace tensorflow

View File

@ -38,11 +38,15 @@ class SoftmaxOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape logits_shape = ctx->InputShape(0);
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
errors::InvalidArgument("logits must be 2-dimensional"));
OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(logits_shape),
errors::InvalidArgument("logits must have >= 1 dimension, got ",
logits_shape.DebugString()));
const int kBatchDim = 0;
const int kClassDim = 1;
// Major dimensions are batch dimensions, minor dimension is the class
// dimension.
std::vector<int64> batch_dims(logits_shape.dims() - 1);
std::iota(batch_dims.begin(), batch_dims.end(), 0);
const int kClassDim = logits_shape.dims() - 1;
const DataType type = input_type(0);
const xla::PrimitiveType xla_type = ctx->input_xla_type(0);
@ -56,7 +60,7 @@ class SoftmaxOp : public XlaOpKernel {
xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim});
// Subtract the max in batch b from every element in batch b. Broadcasts
// along the batch dimension.
auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim});
auto shifted_logits = xla::Sub(logits, logits_max, batch_dims);
auto exp_shifted = xla::Exp(shifted_logits);
const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
xla::PrimitiveType xla_accumulation_type;
@ -71,9 +75,9 @@ class SoftmaxOp : public XlaOpKernel {
auto softmax =
log_
// softmax = shifted_logits - log(sum(exp(shifted_logits)))
? xla::Sub(shifted_logits, xla::Log(sum), {kBatchDim})
? xla::Sub(shifted_logits, xla::Log(sum), batch_dims)
// softmax = exp(shifted_logits) / sum(exp(shifted_logits))
: xla::Div(exp_shifted, sum, {kBatchDim});
: xla::Div(exp_shifted, sum, batch_dims);
ctx->SetOutput(0, softmax);
}

View File

@ -301,6 +301,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
}
REGISTER_XLA_OP(Name("While").AllowResourceTypes(), XlaWhileOp);
REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes(), XlaWhileOp);
REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp);
} // namespace tensorflow

View File

@ -32,6 +32,23 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
return Status::OK();
}
Status HostTensorToMutableBorrowingLiteral(
Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor->dtype(),
host_tensor->shape(), &xla_shape));
return HostTensorToMutableBorrowingLiteral(xla_shape, host_tensor, literal);
}
Status HostTensorToMutableBorrowingLiteral(
const xla::Shape& xla_shape, Tensor* host_tensor,
xla::MutableBorrowingLiteral* literal) {
*literal = xla::MutableBorrowingLiteral(
static_cast<const char*>(DMAHelper::base(host_tensor)), xla_shape);
return Status::OK();
}
Status HostTensorsToBorrowingLiteralTuple(
tensorflow::gtl::ArraySlice<Tensor> host_tensors,
xla::BorrowingLiteral* literal) {

View File

@ -30,6 +30,16 @@ namespace tensorflow {
// 'host_tensor'.
Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
xla::BorrowingLiteral* literal);
// Returns a MutableBorrowingLiteral that utilizes the same underlying buffer
// owned by 'host_tensor', but is mutable via the xla::Literal methods.
Status HostTensorToMutableBorrowingLiteral(
Tensor* host_tensor, xla::MutableBorrowingLiteral* literal);
// Similar as above, except the literal shape is explicitly provided and used
// instead of obtaining it from the 'host_tensor'. The provided literal shape
// 'xla_shape' must be compatible with the shape of 'host_tensor'.
Status HostTensorToMutableBorrowingLiteral(
const xla::Shape& xla_shape, Tensor* host_tensor,
xla::MutableBorrowingLiteral* literal);
// Returns a BorrowingLiteral tuple that utilizes the same underlying buffers
// owned by 'host_tensors'.

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include <queue>
#include <random>
#include <set>
#include <unordered_map>
@ -297,4 +298,29 @@ void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype,
}
}
namespace {
uint32 InitialRandomSeed() {
// Support plumbing the TF seed through to XLA is being worked on.
// If a user wants deterministic behavior, their best option
// is to start with a known checkpoint. This also handles issues when
// multiple random calls can be invoked in any order by TF executor.
// Another option is to use stateless random ops. They have much cleaner
// semantics.
// If a user really wants to set a deterministic seed for XLA-based
// devices, this is the place to do it.
std::random_device rd;
// Make the starting value odd.
return rd() | 1;
}
} // namespace
uint32 GetXLARandomSeed() {
// We initialize counter with an odd number and increment it by two
// everytime. This ensures that it will never be zero, even
// after an overflow. When seeded with zero, some XLA backends
// can return all zeros instead of random numbers.
static std::atomic<uint32> counter(InitialRandomSeed());
return counter.fetch_add(2);
}
} // namespace tensorflow

View File

@ -56,6 +56,9 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges);
void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype,
KernelDef* kdef);
// Returns the next random seed to use for seeding xla rng.
uint32 GetXLARandomSeed();
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_

View File

@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include <cassert>
@ -22,61 +21,42 @@ namespace tensorflow {
XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
AllocMode alloc_mode)
: raw_function_(static_data.raw_function),
result_index_(static_data.result_index),
args_(new void*[static_data.num_args]),
temps_(new void*[static_data.num_temps]),
arg_index_to_temp_index_(new int32[static_data.num_args]),
num_args_(static_data.num_args),
arg_names_(static_data.arg_names),
result_names_(static_data.result_names),
program_shape_(static_data.program_shape),
hlo_profile_printer_data_(static_data.hlo_profile_printer_data) {
: raw_function_(static_data.raw_function_),
result_index_(static_data.result_index_),
buffer_table_(new void*[static_data.num_buffers_]),
buffer_infos_(static_data.buffer_infos_),
arg_index_table_(static_data.arg_index_table_),
num_args_(static_data.num_args_),
arg_names_(static_data.arg_names_),
result_names_(static_data.result_names_),
program_shape_(static_data.program_shape_),
hlo_profile_printer_data_(static_data.hlo_profile_printer_data_) {
bool allocate_entry_params =
alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS;
// Allocate arg and temp buffers.
if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) {
alloc_args_ = cpu_function_runtime::MallocContiguousBuffers(
static_data.arg_sizes, static_data.num_args, args_,
/*annotate_initialized=*/false);
}
alloc_temps_ = cpu_function_runtime::MallocContiguousBuffers(
static_data.temp_sizes, static_data.num_temps, temps_,
alloc_buffer_table_ = cpu_function_runtime::MallocContiguousBuffers(
static_data.buffer_infos_, static_data.num_buffers_,
/*allocate_entry_params=*/allocate_entry_params, buffer_table_,
/*annotate_initialized=*/true);
for (int i = 0; i < static_data.num_temps; i++) {
if (static_data.temp_sizes[i] < -1) {
int32 param_number = -(static_data.temp_sizes[i] + 2);
arg_index_to_temp_index_[param_number] = i;
}
}
// If Hlo profiling is enabled the generated code expects an appropriately
// sized buffer to be passed in as the last argument. If Hlo profiling is
// disabled the last function argument is still present in the function
// signature, but it is ignored by the generated code and we pass in null for
// it.
if (hlo_profiling_enabled()) {
profile_counters_ = new int64[static_data.profile_counters_size]();
profile_counters_ = new int64[static_data.profile_counters_size_]();
}
}
bool XlaCompiledCpuFunction::Run() {
// Propagate pointers to the argument buffers into the temps array. Code
// generated by XLA discovers the incoming argument pointers from the temps
// array.
for (int32 i = 0; i < num_args_; i++) {
temps_[arg_index_to_temp_index_[i]] = args_[i];
}
raw_function_(temps_[result_index_], &run_options_, nullptr, temps_,
profile_counters_);
raw_function_(buffer_table_[result_index_], &run_options_, nullptr,
buffer_table_, profile_counters_);
return true;
}
XlaCompiledCpuFunction::~XlaCompiledCpuFunction() {
cpu_function_runtime::FreeContiguous(alloc_args_);
cpu_function_runtime::FreeContiguous(alloc_temps_);
delete[] args_;
delete[] temps_;
delete[] arg_index_to_temp_index_;
cpu_function_runtime::FreeContiguous(alloc_buffer_table_);
delete[] buffer_table_;
delete[] profile_counters_;
}

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <cassert>
#include <string>
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/core/platform/types.h"
@ -56,46 +57,85 @@ class XlaCompiledCpuFunction {
// StaticData represents the state necessary to run an XLA-compiled
// function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for
// AOT this is backed by data compiled into the object file.
struct StaticData {
//
// The contents of StaticData are XLA-internal implementation details and
// should not be relied on by clients.
//
// TODO(sanjoy): Come up with a cleaner way to express the contraint we want
// here: generated XlaCompiledCpuFunction subclasses should be able to create
// instances of StaticData but only XlaCompiledCpuFunction should be able to
// read from StaticData instances.
class StaticData {
public:
void set_raw_function(RawFunction raw_function) {
raw_function_ = raw_function;
}
void set_buffer_infos(
const cpu_function_runtime::BufferInfo* buffer_infos) {
buffer_infos_ = buffer_infos;
}
void set_num_buffers(size_t num_buffers) { num_buffers_ = num_buffers; }
void set_arg_index_table(const int32* arg_index_table) {
arg_index_table_ = arg_index_table;
}
void set_num_args(int64 num_args) { num_args_ = num_args; }
void set_result_index(size_t result_index) { result_index_ = result_index; }
void set_arg_names(const char** arg_names) { arg_names_ = arg_names; }
void set_result_names(const char** result_names) {
result_names_ = result_names;
}
void set_program_shape(const xla::ProgramShape* program_shape) {
program_shape_ = program_shape;
}
const xla::HloProfilePrinterData* hlo_profile_printer_data() const {
return hlo_profile_printer_data_;
}
void set_hlo_profile_printer_data(
const xla::HloProfilePrinterData* hlo_profile_printer_data) {
hlo_profile_printer_data_ = hlo_profile_printer_data;
}
void set_profile_counters_size(int64 profile_counters_size) {
profile_counters_size_ = profile_counters_size;
}
private:
// The raw function to call.
RawFunction raw_function;
RawFunction raw_function_;
// Cardinality and size of arg buffers.
const intptr_t* arg_sizes = nullptr;
size_t num_args = 0;
// Contains information about the buffers used by the XLA computation.
const cpu_function_runtime::BufferInfo* buffer_infos_ = nullptr;
size_t num_buffers_ = 0;
// Cardinality and size of temp buffers.
//
// If temp_sizes[i] >= 0 then the i'th temp is a regular temporary buffer.
//
// If temp_sizes[i] == -1 then the i'th temp is a constant buffer. The
// corresponding entry in the temp buffer array needs to be set to null.
//
// If temp_sizes[i] < -1 then the i'th temp is the entry parameter
// -(temp_sizes[i] + 2).
const intptr_t* temp_sizes = nullptr;
size_t num_temps = 0;
// Entry parameter i is described by
// buffer_infos[arg_index_table[i]].
const int32* arg_index_table_ = nullptr;
// There are num_args entry parameters.
int64 num_args_ = 0;
// The 0-based index of the result tuple, in the temp buffers.
size_t result_index = 0;
size_t result_index_ = 0;
// [Optional] Arrays of arg and result names. These are arrays of C-style
// strings, where the array is terminated by nullptr.
const char** arg_names = nullptr;
const char** result_names = nullptr;
const char** arg_names_ = nullptr;
const char** result_names_ = nullptr;
// [Optional] Arg and result shapes.
const xla::ProgramShape* program_shape = nullptr;
const xla::ProgramShape* program_shape_ = nullptr;
// [Optional] Profile printer data. Null if profiling is disabled.
const xla::HloProfilePrinterData* hlo_profile_printer_data = nullptr;
const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
// [Optional] The number of profile counters expected in the profile counter
// buffer by the generated code and hlo_profile_printer. 0 if profiling is
// disabled. This information is already present in
// hlo_profile_printer_data but xla::HloProfilePrinterData is forward
// declared so we don't have access to that information here.
int64 profile_counters_size = 0;
int64 profile_counters_size_ = 0;
// Only XlaCompiledCpuFunction is allowed to read the above fields.
friend class XlaCompiledCpuFunction;
};
// AllocMode controls the buffer allocation mode.
@ -135,14 +175,25 @@ class XlaCompiledCpuFunction {
// ------------------------------
// Arg methods for managing input buffers. Buffers are in row-major order.
// Returns the underlying array of argument buffers, where args()[I] is the
// buffer for the positional argument at index I.
void** args() { return args_; }
const void* const* args() const { return args_; }
// Returns the buffer for the positional argument at the given `index`.
void* arg_data(size_t index) { return args_[index]; }
const void* arg_data(size_t index) const { return args_[index]; }
void* arg_data(size_t index) {
return buffer_table_[arg_index_table_[index]];
}
const void* arg_data(size_t index) const {
return buffer_table_[arg_index_table_[index]];
}
int num_args() const { return num_args_; }
// Returns the size of entry parameter `idx`.
//
// There is a static version of this method on tfcompile generated subclasses
// of XlaCompiledCpuFunction, but try to prefer this when possible since it
// works both for XlaJitCompiledCpuFunction and AOT compiled subclasses.
int arg_size(int idx) const {
assert(idx < num_args());
return buffer_infos_[arg_index_table_[idx]].size();
}
// Sets the buffer for the positional argument at the given `index` to `data`.
// Must be called before Run to have an effect. May be called under any
@ -155,7 +206,9 @@ class XlaCompiledCpuFunction {
//
// Aliasing of argument and result buffers is not allowed, and results in
// undefined behavior.
void set_arg_data(size_t index, void* data) { args_[index] = data; }
void set_arg_data(size_t index, void* data) {
buffer_table_[arg_index_table_[index]] = data;
}
// ------------------------------
// Result methods for managing output buffers. Buffers are in row-major order.
@ -165,9 +218,9 @@ class XlaCompiledCpuFunction {
// Returns the underlying array of result buffers, where results()[I] is the
// buffer for the positional result at index I.
void** results() { return static_cast<void**>(temps_[result_index_]); }
void** results() { return static_cast<void**>(buffer_table_[result_index_]); }
const void* const* results() const {
return static_cast<const void* const*>(temps_[result_index_]);
return static_cast<const void* const*>(buffer_table_[result_index_]);
}
// Profile counters for this XLA computation.
@ -225,25 +278,28 @@ class XlaCompiledCpuFunction {
const RawFunction raw_function_;
const size_t result_index_;
// Arrays of argument and temp buffers; entries in args_ may be overwritten by
// the user.
void** args_ = nullptr;
void** temps_ = nullptr;
// Array containing pointers to argument and temp buffers (slots corresponding
// to constant and on-stack buffers are null).
void** const buffer_table_;
// Argument i needs to be placed in temps_[arg_index_to_temp_index_[i]] for
// XLA generated code to be able to find it.
// Describes the buffers used by the XLA computation.
const cpu_function_runtime::BufferInfo* const buffer_infos_;
// Argument i needs to be placed in buffer_table_[arg_index_to_temp_index_[i]]
// for XLA generated code to be able to find it.
//
// For now we need to keep around the args_ array because there is code that
// depends on args() returning a void**. However, in the future we may remove
// args_ in favor of using temps_ as the sole storage for the arguments.
int32* arg_index_to_temp_index_;
// args_ in favor of using buffer_table_ as the sole storage for the
// arguments.
const int32* const arg_index_table_;
// The number of incoming arguments.
int32 num_args_;
const int32 num_args_;
// Backing memory for individual arg and temp buffers.
void* alloc_args_ = nullptr;
void* alloc_temps_ = nullptr;
// Backing memory for buffer_table_ and args_, the latter depending on
// AllocMode.
void* alloc_buffer_table_ = nullptr;
// Backing memory for profiling counters.
int64* profile_counters_ = nullptr;

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -35,45 +36,6 @@ limitations under the License.
namespace tensorflow {
namespace {
// Returns a vector of positional argument buffer sizes.
xla::StatusOr<std::vector<intptr_t>> ComputeArgSizes(
const xla::ProgramShape& program_shape) {
std::vector<intptr_t> arg_sizes;
const size_t num_args = program_shape.parameters_size();
arg_sizes.reserve(num_args);
for (int i = 0; i < num_args; ++i) {
const xla::Shape& arg_shape = program_shape.parameters(i);
constexpr size_t kPointerSize = sizeof(void*);
arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize));
}
return std::move(arg_sizes);
}
// Returns a vector of positional temporary buffer sizes.
xla::StatusOr<std::vector<intptr_t>> ComputeTempSizes(
const xla::BufferAssignment& buffer_assignment) {
const std::vector<xla::BufferAllocation>& allocations =
buffer_assignment.Allocations();
std::vector<intptr_t> temp_sizes;
temp_sizes.reserve(allocations.size());
for (const xla::BufferAllocation& allocation : allocations) {
if (allocation.is_constant() || allocation.is_thread_local()) {
// Constants are lowered to globals. Thread locals are lowered to
// allocas.
temp_sizes.push_back(-1);
} else if (allocation.is_entry_computation_parameter()) {
// Entry computation parameters need some preprocessing in
// XlaCompiledCpuFunction::Run. See the comment on
// XlaCompiledCpuFunction::StaticData::temp_sizes.
temp_sizes.push_back(-allocation.parameter_number() - 2);
} else {
temp_sizes.push_back(allocation.size());
}
}
return std::move(temp_sizes);
}
// Returns the index of the result in the temp buffers.
xla::StatusOr<size_t> ComputeResultIndex(
const xla::BufferAssignment& buffer_assignment) {
@ -157,11 +119,11 @@ XlaJitCompiledCpuFunction::Compile(
const xla::BufferAssignment& buffer_assignment =
cpu_executable->buffer_assignment();
// Compute buffer sizes and the result index, needed to run the raw function.
TF_ASSIGN_OR_RETURN(std::vector<intptr_t> arg_sizes,
ComputeArgSizes(*program_shape));
TF_ASSIGN_OR_RETURN(std::vector<intptr_t> temp_sizes,
ComputeTempSizes(buffer_assignment));
// Compute buffer infos and the result index, needed to run the raw function.
std::vector<cpu_function_runtime::BufferInfo> buffer_infos =
xla::cpu::CreateBufferInfosFromBufferAssignment(buffer_assignment);
std::vector<int32> arg_index_table =
xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos);
TF_ASSIGN_OR_RETURN(size_t result_index,
ComputeResultIndex(buffer_assignment));
@ -169,28 +131,28 @@ XlaJitCompiledCpuFunction::Compile(
new XlaJitCompiledCpuFunction);
XlaJitCompiledCpuFunction* jit = jit_unique_ptr.get();
jit->executable_ = std::move(executable);
jit->arg_sizes_ = std::move(arg_sizes);
jit->temp_sizes_ = std::move(temp_sizes);
jit->buffer_infos_ = std::move(buffer_infos);
jit->arg_index_table_ = std::move(arg_index_table);
jit->program_shape_ = std::move(program_shape);
jit->static_data_.raw_function = std::move(raw_function);
jit->static_data_.arg_sizes = jit->arg_sizes_.data();
jit->static_data_.num_args = jit->arg_sizes_.size();
jit->static_data_.temp_sizes = jit->temp_sizes_.data();
jit->static_data_.num_temps = jit->temp_sizes_.size();
jit->static_data_.result_index = result_index;
jit->static_data_.set_raw_function(raw_function);
jit->static_data_.set_buffer_infos(jit->buffer_infos_.data());
jit->static_data_.set_num_buffers(jit->buffer_infos_.size());
jit->static_data_.set_arg_index_table(jit->arg_index_table_.data());
jit->static_data_.set_num_args(jit->arg_index_table_.size());
jit->static_data_.set_result_index(result_index);
// Optional metadata is collected and set below.
CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_);
CollectNames(config.fetch(), &jit->nonempty_result_names_,
&jit->result_names_);
jit->static_data_.arg_names = jit->arg_names_.data();
jit->static_data_.result_names = jit->result_names_.data();
jit->static_data_.program_shape = jit->program_shape_.get();
jit->static_data_.set_arg_names(jit->arg_names_.data());
jit->static_data_.set_result_names(jit->result_names_.data());
jit->static_data_.set_program_shape(jit->program_shape_.get());
if (cpu_executable->hlo_profiling_enabled()) {
jit->static_data_.hlo_profile_printer_data =
&cpu_executable->hlo_profile_printer_data();
jit->static_data_.profile_counters_size =
cpu_executable->hlo_profile_printer_data().profile_counters_size();
jit->static_data_.set_hlo_profile_printer_data(
&cpu_executable->hlo_profile_printer_data());
jit->static_data_.set_profile_counters_size(
cpu_executable->hlo_profile_printer_data().profile_counters_size());
}
return std::move(jit_unique_ptr);

View File

@ -66,9 +66,11 @@ class XlaJitCompiledCpuFunction {
// The static data is backed by the rest of the state in this class.
XlaCompiledCpuFunction::StaticData static_data_;
// The backing arrays of arg and temp buffer sizes.
std::vector<intptr_t> arg_sizes_;
std::vector<intptr_t> temp_sizes_;
// The backing array for buffer infos.
std::vector<cpu_function_runtime::BufferInfo> buffer_infos_;
// The backing array for the arg index table.
std::vector<int32> arg_index_table_;
// The backing arrays of arg and result names. We hold the actual strings in
// nonempty_*_names_, and hold arrays of pointers in *_names_ for the static

View File

@ -409,7 +409,7 @@ class Array {
// Returns the total number of elements in the array.
int64 num_elements() const {
return std::accumulate(sizes_.begin(), sizes_.end(), 1,
return std::accumulate(sizes_.begin(), sizes_.end(), 1LL,
std::multiplies<int64>());
}

View File

@ -121,6 +121,30 @@ xla_test(
],
)
cc_library(
name = "pooling",
srcs = ["pooling.cc"],
hdrs = ["pooling.h"],
deps = [
":arithmetic",
":constants",
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/core:lib",
],
)
xla_test(
name = "pooling_test",
srcs = ["pooling_test.cc"],
deps = [
":pooling",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
cc_library(
name = "prng",
srcs = ["prng.cc"],
@ -144,7 +168,7 @@ cc_library(
":numeric",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client:xla_builder",
],
)
@ -161,7 +185,7 @@ xla_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],

View File

@ -0,0 +1,183 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/pooling.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
namespace xla {
namespace {
// Common computation shared between AvgPool and AvgPoolGrad. Divide each
// element of an image by the count of elements that contributed to that
// element during pooling.
XlaOp AvgPoolDivideByCountWithGeneralPadding(
XlaOp sums, PrimitiveType dtype,
tensorflow::gtl::ArraySlice<int64> input_shape,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> spatial_padding,
tensorflow::gtl::ArraySlice<int64> ksize,
tensorflow::gtl::ArraySlice<int64> stride,
const TensorFormat& data_format) {
// The padding shouldn't be included in the counts. We use another
// ReduceWindow to find the right counts.
const int num_spatial_dims = spatial_padding.size();
std::vector<int64> input_dim_sizes(num_spatial_dims);
std::vector<int64> window_dims(num_spatial_dims);
std::vector<int64> window_ksize(num_spatial_dims);
std::vector<int64> window_stride(num_spatial_dims);
CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
<< "Invalid number of spatial dimentions in data format specification";
for (int i = 0; i < num_spatial_dims; ++i) {
int dim = data_format.spatial_dimension(i);
input_dim_sizes[i] = input_shape[dim];
window_dims[i] = dim;
window_ksize[i] = ksize[dim];
window_stride[i] = stride[dim];
}
XlaBuilder* b = sums.builder();
// Build a matrix of all 1s, with the same width/height as the input.
auto ones = Broadcast(One(b, dtype), input_dim_sizes);
PaddingConfig padding_config;
for (int i = 0; i < num_spatial_dims; ++i) {
auto dims = padding_config.add_dimensions();
dims->set_edge_padding_low(spatial_padding[i].first);
dims->set_edge_padding_high(spatial_padding[i].second);
}
auto zero = Zero(b, dtype);
auto padded_ones = Pad(ones, zero, padding_config);
// Perform a ReduceWindow with the same window size, strides, and padding
// to count the number of contributions to each result element.
auto counts =
ReduceWindow(padded_ones, zero, CreateScalarAddComputation(dtype, b),
window_ksize, window_stride, Padding::kValid);
return Div(sums, counts, window_dims);
}
// Sums all elements in the window specified by 'kernel_size' and 'stride'.
XlaOp ComputeSums(XlaOp operand, XlaOp init_value,
tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride,
const TensorFormat& data_format) {
XlaBuilder* b = operand.builder();
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
TF_ASSIGN_OR_RETURN(Shape init_shape, b->GetShape(init_value));
PrimitiveType accumulation_type = init_shape.element_type();
auto add_computation = CreateScalarAddComputation(accumulation_type, b);
return ReduceWindow(operand, init_value, add_computation, kernel_size,
stride, Padding::kValid);
});
}
// Creates a padding configuration out of spatial padding values.
PaddingConfig MakeSpatialPaddingConfig(
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> spatial_padding,
tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride,
const TensorFormat& data_format) {
const int num_spatial_dims = kernel_size.size() - 2;
PaddingConfig padding_config;
for (int i = 0; i < 2 + num_spatial_dims; ++i) {
padding_config.add_dimensions();
}
CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
<< "Invalid number of spatial dimentions in data format specification";
for (int i = 0; i < num_spatial_dims; ++i) {
int dim = data_format.spatial_dimension(i);
auto padding_dimension = padding_config.mutable_dimensions(dim);
padding_dimension->set_edge_padding_low(spatial_padding[i].first);
padding_dimension->set_edge_padding_high(spatial_padding[i].second);
}
return padding_config;
}
} // namespace
XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
const TensorFormat& data_format) {
XlaBuilder* b = operand.builder();
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
PrimitiveType dtype = operand_shape.element_type();
auto max_computation = CreateScalarMaxComputation(dtype, b);
auto init_value = MinValue(b, dtype);
return ReduceWindow(operand, init_value, max_computation, kernel_size,
stride, padding);
});
}
XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const TensorFormat& data_format,
const bool counts_include_padding) {
XlaBuilder* b = operand.builder();
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
PrimitiveType dtype = operand_shape.element_type();
auto init_value = Zero(b, dtype);
std::vector<int64> input_size(operand_shape.dimensions().begin(),
operand_shape.dimensions().end());
auto padding_config =
MakeSpatialPaddingConfig(padding, kernel_size, stride, data_format);
auto padded_operand = Pad(operand, Zero(b, dtype), padding_config);
auto pooled = ComputeSums(padded_operand, init_value, kernel_size, stride,
data_format);
if (counts_include_padding) {
// If counts include padding, all windows have the same number of elements
// contributing to each average. Divide by the window size everywhere to
// get the average.
int64 window_size =
std::accumulate(kernel_size.begin(), kernel_size.end(), 1,
[](int64 x, int64 y) { return x * y; });
auto divisor = ConstantR0WithType(b, dtype, window_size);
return pooled / divisor;
} else {
return AvgPoolDivideByCountWithGeneralPadding(
pooled, dtype, input_size, padding, kernel_size, stride, data_format);
}
});
}
std::vector<std::pair<int64, int64>> MakeSpatialPadding(
tensorflow::gtl::ArraySlice<int64> input_size,
tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
const TensorFormat& data_format) {
const int num_spatial_dims = kernel_size.size() - 2;
std::vector<int64> input_spatial_dimensions;
std::vector<int64> kernel_size_spatial_dimensions;
std::vector<int64> stride_spatial_dimensions;
CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
<< "Invalid number of spatial dimentions in data format specification";
for (int i = 0; i < num_spatial_dims; ++i) {
int dim = data_format.spatial_dimension(i);
input_spatial_dimensions.push_back(input_size[dim]);
kernel_size_spatial_dimensions.push_back(kernel_size[dim]);
stride_spatial_dimensions.push_back(stride[dim]);
}
return MakePadding(input_spatial_dimensions, kernel_size_spatial_dimensions,
stride_spatial_dimensions, padding);
}
} // namespace xla

View File

@ -0,0 +1,73 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
namespace xla {
// Tensor format for reduce window operations.
class TensorFormat {
public:
TensorFormat(int batch_dimension, int feature_dimension,
tensorflow::gtl::ArraySlice<int64> spatial_dimensions)
: batch_dimension_(batch_dimension),
feature_dimension_(feature_dimension),
spatial_dimensions_(spatial_dimensions.begin(),
spatial_dimensions.end()) {}
int batch_dimension() const { return batch_dimension_; }
int feature_dimension() const { return feature_dimension_; }
int spatial_dimension(int dim) const { return spatial_dimensions_[dim]; }
int num_spatial_dims() const { return spatial_dimensions_.size(); }
private:
// The number of the dimension that represents the batch.
int batch_dimension_;
// The number of the dimension that represents the features.
int feature_dimension_;
// The dimension numbers for the spatial dimensions.
tensorflow::gtl::InlinedVector<int, 4> spatial_dimensions_;
};
// Computes the max pool of 'operand'.
XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
const TensorFormat& data_format);
// Computes the average pool of 'operand'.
XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const TensorFormat& data_format,
const bool counts_include_padding);
// Returns the list of low and high padding elements in each spatial dimension
// for the given 'padding' specification.
std::vector<std::pair<int64, int64>> MakeSpatialPadding(
tensorflow::gtl::ArraySlice<int64> input_size,
tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
const TensorFormat& data_format);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_

View File

@ -0,0 +1,185 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/pooling.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
namespace xla {
namespace {
TensorFormat MakeNCHWFormat(int num_spatial_dims) {
tensorflow::gtl::InlinedVector<int64, 4> spatial_dimensions;
for (int i = 0; i < num_spatial_dims; ++i) {
spatial_dimensions.push_back(i + 2);
}
return TensorFormat(/*batch_dimension=*/0, /*feature_dimension=*/1,
/*spatial_dimensions=*/spatial_dimensions);
}
std::vector<std::pair<int64, int64>> MakeGeneralPadding(
XlaOp input, tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
const xla::TensorFormat& data_format) {
XlaBuilder* b = input.builder();
Shape operand_shape = b->GetShape(input).ValueOrDie();
std::vector<int64> input_size(operand_shape.dimensions().begin(),
operand_shape.dimensions().end());
return MakeSpatialPadding(input_size, kernel_size, stride, padding,
data_format);
}
// Add singleton batch and feature dimensions to spatial dimensions, according
// to 'data_format' specification.
std::vector<int64> ExpandWithBatchAndFeatureDimensions(
tensorflow::gtl::ArraySlice<int64> spatial_dim_sizes,
const xla::TensorFormat& data_format) {
const int num_spatial_dims = spatial_dim_sizes.size();
std::vector<int64> tensor_sizes(num_spatial_dims + 2, 1);
for (int i = 0; i < num_spatial_dims; ++i) {
int dim = data_format.spatial_dimension(i);
tensor_sizes[dim] = spatial_dim_sizes[i];
}
return tensor_sizes;
}
class PoolingTest : public ClientLibraryTestBase {
public:
ErrorSpec error_spec_{0.0001};
};
XLA_TEST_F(PoolingTest, MaxPool2D) {
XlaBuilder builder(TestName());
XlaOp input = ConstantR4FromArray4D<float>(
&builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
auto data_format = MakeNCHWFormat(2);
auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
auto stride = kernel_size;
MaxPool(input, kernel_size, stride, Padding::kValid, data_format);
ComputeAndCompareR4<float>(&builder, {{{{5, 4}}}}, {}, error_spec_);
}
XLA_TEST_F(PoolingTest, MaxPool2DWithPadding) {
XlaBuilder builder(TestName());
XlaOp input = ConstantR4FromArray4D<float>(
&builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
auto data_format = MakeNCHWFormat(2);
auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
auto stride = kernel_size;
MaxPool(input, kernel_size, stride, Padding::kSame, data_format);
ComputeAndCompareR4<float>(&builder, {{{{5, 4, 5}}}}, {}, error_spec_);
}
XLA_TEST_F(PoolingTest, MaxPool2DWithPaddingAndStride) {
XlaBuilder builder(TestName());
XlaOp input = ConstantR4FromArray4D<float>(
&builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
auto data_format = MakeNCHWFormat(2);
auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format);
MaxPool(input, kernel_size, stride, Padding::kSame, data_format);
ComputeAndCompareR4<float>(&builder, {{{{5, 4, 4, 5, 5}, {5, 4, 3, 2, 1}}}},
{}, error_spec_);
}
XLA_TEST_F(PoolingTest, AvgPool2D) {
XlaBuilder builder(TestName());
XlaOp input = ConstantR4FromArray4D<float>(
&builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
auto data_format = MakeNCHWFormat(2);
auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
auto stride = kernel_size;
auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kValid,
data_format);
AvgPool(input, kernel_size, stride, padding, data_format,
/*counts_include_padding=*/true);
ComputeAndCompareR4<float>(&builder, {{{{3, 3}}}}, {}, error_spec_);
}
XLA_TEST_F(PoolingTest, AvgPool2DWithPadding) {
XlaBuilder builder(TestName());
XlaOp input = ConstantR4FromArray4D<float>(
&builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
auto data_format = MakeNCHWFormat(2);
auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
auto stride = kernel_size;
auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kSame,
data_format);
AvgPool(input, kernel_size, stride, padding, data_format,
/*counts_include_padding=*/false);
ComputeAndCompareR4<float>(&builder, {{{{3, 3, 3}}}}, {}, error_spec_);
}
XLA_TEST_F(PoolingTest, AvgPool2DWithPaddingAndStride) {
XlaBuilder builder(TestName());
XlaOp input = ConstantR4FromArray4D<float>(
&builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
auto data_format = MakeNCHWFormat(2);
auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format);
auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kSame,
data_format);
AvgPool(input, kernel_size, stride, padding, data_format,
/*counts_include_padding=*/false);
ComputeAndCompareR4<float>(&builder,
{{{{3, 3, 3, 3, 3}, {4.5, 3.5, 2.5, 1.5, 1}}}}, {},
error_spec_);
}
XLA_TEST_F(PoolingTest, AvgPool2DWithGeneralPaddingCountNotIncludePadding) {
XlaBuilder builder(TestName());
XlaOp input = ConstantR4FromArray4D<float>(
&builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
auto data_format = MakeNCHWFormat(2);
auto kernel_size = ExpandWithBatchAndFeatureDimensions({3, 3}, data_format);
auto stride = kernel_size;
AvgPool(input, kernel_size, stride, {{1, 1}, {2, 1}}, data_format,
/*counts_include_padding=*/false);
ComputeAndCompareR4<float>(&builder, {{{{3, 3}}}}, {}, error_spec_);
}
XLA_TEST_F(PoolingTest,
AvgPool2DWithGeneralPaddingCountNotIncludePaddingAndStride) {
XlaBuilder builder(TestName());
XlaOp input = ConstantR4FromArray4D<float>(
&builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
auto data_format = MakeNCHWFormat(2);
auto kernel_size = ExpandWithBatchAndFeatureDimensions({3, 3}, data_format);
auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
AvgPool(input, kernel_size, stride, {{2, 1}, {1, 1}}, data_format,
/*counts_include_padding=*/false);
ComputeAndCompareR4<float>(&builder, {{{{1.5, 3, 4.5}, {3, 3, 3}}}}, {},
error_spec_);
}
} // namespace
} // namespace xla

View File

@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

View File

@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"

View File

@ -303,7 +303,7 @@ StatusOr<std::unique_ptr<Literal>> LocalClient::TransferFromOutfeedLocal(
const Shape& shape, int device_ordinal) {
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
backend().stream_executor(device_ordinal));
auto literal = MakeUnique<Literal>();
auto literal = Literal::CreateFromShape(shape);
TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
executor, shape, literal.get()));
return std::move(literal);

View File

@ -45,21 +45,6 @@ int64 GetUniqueId() {
return id;
}
// Returns true if an instruction with the given opcode can be the root of the
// computation.
bool CanBeRoot(HloOpcode opcode) {
switch (opcode) {
case HloOpcode::kAfterAll:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kOutfeed:
case HloOpcode::kTrace:
return false;
default:
return true;
}
}
} // namespace
XlaOp operator-(const XlaOp& x) { return Neg(x); }
@ -142,28 +127,13 @@ XlaOp XlaBuilder::ReportErrorOrReturn(
return ReportErrorOrReturn(op_creator());
}
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) const {
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
TF_RETURN_IF_ERROR(first_error_);
TF_RET_CHECK(root_id != nullptr);
TF_RET_CHECK((root_id >= 0) && (root_id < instructions_.size()));
ProgramShape program_shape;
// Not all instructions can be roots. Walk backwards from the last added
// instruction until a valid root is found.
int64 index = instructions_.size() - 1;
for (; index >= 0; index--) {
TF_ASSIGN_OR_RETURN(HloOpcode opcode,
StringToHloOpcode(instructions_[index].opcode()));
if (CanBeRoot(opcode)) {
break;
}
}
if (index < 0) {
return FailedPrecondition("no root instruction was found");
}
*root_id = instructions_[index].id();
*program_shape.mutable_result() = instructions_[index].shape();
*program_shape.mutable_result() = instructions_[root_id].shape();
// Check that the parameter numbers are continuous from 0, and add parameter
// shapes and names to the program shape.
@ -188,8 +158,15 @@ StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) const {
}
StatusOr<ProgramShape> XlaBuilder::GetProgramShape() const {
int64 root;
return GetProgramShape(&root);
TF_RET_CHECK(!instructions_.empty());
return GetProgramShape(instructions_.back().id());
}
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(XlaOp root) const {
if (root.builder_ != this) {
return InvalidArgument("Given root operation is not in this computation.");
}
return GetProgramShape(root.handle());
}
void XlaBuilder::IsConstantVisitor(const int64 op_handle,
@ -257,17 +234,29 @@ StatusOr<XlaComputation> XlaBuilder::Build() {
first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
return AppendStatus(first_error_, backtrace);
}
return Build(instructions_.back().id());
}
StatusOr<XlaComputation> XlaBuilder::Build(XlaOp root) {
if (root.builder_ != this) {
return InvalidArgument("Given root operation is not in this computation.");
}
return Build(root.handle());
}
StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {
if (!first_error_.ok()) {
string backtrace;
first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
return AppendStatus(first_error_, backtrace);
}
HloComputationProto entry;
entry.set_id(GetUniqueId()); // Give the computation a global unique id.
entry.set_name(StrCat(name_, entry.id())); // Ensure that the name is unique.
{
int64 root_id;
TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(),
GetProgramShape(&root_id));
TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), GetProgramShape(root_id));
entry.set_root_id(root_id);
}
for (auto& instruction : instructions_) {
// Ensures that the instruction names are unique among the whole graph.
@ -1099,11 +1088,11 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
sharding_builder::AssignDevice(0);
XlaScopedShardingAssignment scoped_sharding(this,
infeed_instruction_sharding);
TF_ASSIGN_OR_RETURN(infeed,
AddInstruction(std::move(instr), HloOpcode::kInfeed));
TF_ASSIGN_OR_RETURN(
infeed, AddInstruction(std::move(instr), HloOpcode::kInfeed, {}));
} else {
TF_ASSIGN_OR_RETURN(infeed,
AddInstruction(std::move(instr), HloOpcode::kInfeed));
TF_ASSIGN_OR_RETURN(
infeed, AddInstruction(std::move(instr), HloOpcode::kInfeed, {}));
}
// The infeed instruction produces a tuple of the infed data and a token
@ -1892,6 +1881,61 @@ XlaOp XlaBuilder::CrossReplicaSum(
});
}
XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
// The HloInstruction for Alltoall currently only handles the data
// communication: it accepts N already split parts and scatters them to N
// cores, and each core gathers the N received parts into a tuple as the
// output. So here we explicitly split the operand before the hlo alltoall,
// and concat the tuple elements.
//
// First, run shape inference to make sure the shapes are valid.
TF_RETURN_IF_ERROR(
ShapeInference::InferAllToAllShape(operand_shape, split_dimension,
concat_dimension, split_count)
.status());
// Split into N parts.
std::vector<XlaOp> slices;
slices.reserve(split_count);
const int64 block_size =
operand_shape.dimensions(split_dimension) / split_count;
for (int i = 0; i < split_count; i++) {
slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size,
/*limit_index=*/(i + 1) * block_size,
/*stride=*/1, /*dimno=*/split_dimension));
}
// Handle data communication.
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices));
std::vector<const Shape*> slice_shape_ptrs;
c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs),
[](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));
for (const ReplicaGroup& group : replica_groups) {
*instr.add_replica_groups() = group;
}
TF_ASSIGN_OR_RETURN(
XlaOp alltoall,
AddInstruction(std::move(instr), HloOpcode::kAllToAll, slices));
// Concat the N received parts.
std::vector<XlaOp> received;
received.reserve(split_count);
for (int i = 0; i < split_count; i++) {
received.push_back(this->GetTupleElement(alltoall, i));
}
return this->ConcatInDim(received, concat_dimension);
});
}
XlaOp XlaBuilder::SelectAndScatter(
const XlaOp& operand, const XlaComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
@ -2163,11 +2207,6 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
LookUpInstruction(root_op));
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
if (!CanBeRoot(opcode)) {
return InvalidArgument("the operand with opcode %s cannot be root",
root->opcode().c_str());
}
HloComputationProto entry;
entry.set_id(GetUniqueId()); // Give the computation a global unique id.
@ -2693,6 +2732,13 @@ XlaOp CrossReplicaSum(
replica_group_ids, channel_id);
}
XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups) {
return operand.builder()->AllToAll(operand, split_dimension, concat_dimension,
split_count, replica_groups);
}
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,

View File

@ -195,9 +195,14 @@ class XlaBuilder {
// Builds the computation with the requested operations, or returns a non-ok
// status. Note that all ops that have been enqueued will be moved to the
// computation being returned.
// computation being returned. The root of the computation will be the last
// added operation.
StatusOr<XlaComputation> Build();
// Overload of Build which specifies a particular root instruction for the
// computation.
StatusOr<XlaComputation> Build(XlaOp root);
// Builds the computation with the requested operations, or notes an error in
// the parent XlaBuilder and returns an empty computation if building failed.
// This function is intended to be used where the returned XlaComputation is
@ -225,9 +230,14 @@ class XlaBuilder {
// Returns the shape of the given op.
StatusOr<Shape> GetShape(const XlaOp& op) const;
// Returns the (inferred) result for the current computation's shape.
// Returns the (inferred) result for the current computation's shape. This
// assumes the root instruction is the last added instruction.
StatusOr<ProgramShape> GetProgramShape() const;
// Returns the (inferred) result for the current computation's shape using the
// given operation as the root.
StatusOr<ProgramShape> GetProgramShape(XlaOp root) const;
// Reports an error to the builder, by
// * storing it internally and capturing a backtrace if it's the first error
// (this deferred value will be produced on the call to
@ -255,6 +265,9 @@ class XlaBuilder {
StatusOr<bool> IsConstant(const XlaOp& operand) const;
private:
// Build helper which takes the id of the root operation..
StatusOr<XlaComputation> Build(int64 root_id);
// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
XlaOp Parameter(int64 parameter_number, const Shape& shape,
@ -686,9 +699,9 @@ class XlaBuilder {
// For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
//
// - `channel_id`: for Allreduce nodes from different models, if they have the
// same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be
// applied cross models.
// - `channel_id`: for Allreduce nodes from different modules, if they have
// the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will
// not be applied cross modules.
//
// TODO(b/79737069): Rename this to AllReduce when it's ready to use.
XlaOp CrossReplicaSum(
@ -697,6 +710,13 @@ class XlaBuilder {
const tensorflow::gtl::optional<ChannelHandle>& channel_id =
tensorflow::gtl::nullopt);
// Enqueues an operation that do an Alltoall of the operand cross cores.
//
// TODO(b/110096724): This is NOT YET ready to use.
XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups);
// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
@ -969,9 +989,8 @@ class XlaBuilder {
// shape.
StatusOr<XlaOp> Reshape(const Shape& shape, const XlaOp& operand);
// Returns the (inferred) result for the program shape for the current
// computation and fills the root_id in the pointer.
StatusOr<ProgramShape> GetProgramShape(int64* root_id) const;
// Returns the (inferred) result for the program shape using the given root.
StatusOr<ProgramShape> GetProgramShape(int64 root_id) const;
// Returns shapes for the operands.
StatusOr<std::vector<Shape>> GetOperandShapes(
@ -1234,6 +1253,9 @@ class XlaBuilder {
const XlaOp& operand, const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> replica_group_ids,
const tensorflow::gtl::optional<ChannelHandle>& channel_id);
friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups);
friend XlaOp SelectAndScatter(
const XlaOp& operand, const XlaComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
@ -1820,9 +1842,9 @@ XlaOp CrossReplicaSum(
// For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
//
// - `channel_id`: for Allreduce nodes from different models, if they have the
// - `channel_id`: for Allreduce nodes from different modules, if they have the
// same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be
// applied cross models.
// applied cross modules.
//
// TODO(b/79737069): Rename this to AllReduce when it's ready to use.
XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation,
@ -1830,6 +1852,13 @@ XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation,
const tensorflow::gtl::optional<ChannelHandle>&
channel_id = tensorflow::gtl::nullopt);
// Enqueues an operation that do an Alltoall of the operand cross cores.
//
// TODO(b/110096724): This is NOT YET ready to use.
XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups = {});
// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@ -46,6 +47,17 @@ class XlaBuilderTest : public ::testing::Test {
return HloModule::CreateFromProto(proto, config);
}
// Overload which explicitly specifies the root instruction.
StatusOr<std::unique_ptr<HloModule>> BuildHloModule(XlaBuilder* b,
XlaOp root) {
TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build(root));
const HloModuleProto& proto = computation.proto();
TF_ASSIGN_OR_RETURN(const auto& config,
HloModule::CreateModuleConfigFromProto(
proto, legacy_flags::GetDebugOptionsFromFlags()));
return HloModule::CreateFromProto(proto, config);
}
// Returns the name of the test currently being run.
string TestName() const {
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
@ -293,6 +305,21 @@ TEST_F(XlaBuilderTest, Transpose) {
EXPECT_THAT(root, op::Transpose(op::Parameter()));
}
TEST_F(XlaBuilderTest, AllToAll) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0,
/*split_count=*/2);
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();
// AllToAll is decomposed into slices -> all-to-all -> gte -> concat.
EXPECT_EQ(root->opcode(), HloOpcode::kConcatenate);
EXPECT_EQ(root->operand(0)->operand(0)->opcode(), HloOpcode::kAllToAll);
EXPECT_TRUE(
ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8})));
}
TEST_F(XlaBuilderTest, ReportError) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
@ -320,5 +347,45 @@ TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) {
EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error"));
}
TEST_F(XlaBuilderTest, BuildWithSpecificRoot) {
XlaBuilder b(TestName());
XlaOp constant = ConstantR0<float>(&b, 1.0);
Add(constant, ConstantR0<float>(&b, 2.0));
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/constant));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Constant());
}
TEST_F(XlaBuilderTest, BuildWithSpecificRootAndMultipleParameters) {
// Specifying a particular root in Build should still include all entry
// parameters.
XlaBuilder b(TestName());
const Shape shape = ShapeUtil::MakeShape(F32, {42, 123});
XlaOp x = Parameter(&b, 0, shape, "x");
XlaOp y = Parameter(&b, 1, shape, "y");
XlaOp z = Parameter(&b, 2, shape, "z");
Add(x, Sub(y, z));
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/x));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Parameter());
EXPECT_EQ(module->entry_computation()->num_parameters(), 3);
EXPECT_EQ(module->entry_computation()->instruction_count(), 5);
}
TEST_F(XlaBuilderTest, BuildWithSpecificRootWithWrongBuilder) {
XlaBuilder b(TestName());
XlaBuilder other_b(TestName());
const Shape shape = ShapeUtil::MakeShape(F32, {42, 123});
Parameter(&b, 0, shape, "param");
XlaOp other_param = Parameter(&other_b, 0, shape, "other_param");
Status status = b.Build(other_param).status();
ASSERT_IS_NOT_OK(status);
EXPECT_THAT(
status.error_message(),
::testing::HasSubstr("root operation is not in this computation"));
}
} // namespace
} // namespace xla

View File

@ -1,33 +0,0 @@
# Description:
# The new XLA client libraries.
licenses(["notice"]) # Apache 2.0
package(default_visibility = [":friends"])
package_group(
name = "friends",
includes = [
"//tensorflow/compiler/xla:friends",
],
)
# Filegroup used to collect source files for dependency checking.
filegroup(
name = "c_srcs",
data = glob([
"**/*.cc",
"**/*.h",
]),
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
cc_library(
name = "xla_builder",
hdrs = ["xla_builder.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla/client:xla_builder",
],
)

View File

@ -71,7 +71,7 @@ std::ostream& operator<<(std::ostream& out, const Literal& literal) {
return out;
}
Literal::StrideConfig::StrideConfig(
MutableLiteralBase::StrideConfig::StrideConfig(
const Shape& source_shape, const Shape& dest_shape,
tensorflow::gtl::ArraySlice<int64> dimensions)
: dimensions(dimensions),
@ -133,7 +133,8 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
}
Literal::Literal(const Shape& shape, bool allocate_arrays)
: LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
: MutableLiteralBase() {
shape_ = MakeUnique<Shape>(shape);
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
@ -159,7 +160,9 @@ void Literal::DeallocateBuffers() {
});
}
Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); }
Literal::Literal(Literal&& other) : MutableLiteralBase() {
*this = std::move(other);
}
Literal& Literal::operator=(Literal&& other) {
DCHECK(&other.root_piece_->subshape() == other.shape_.get());
@ -187,12 +190,13 @@ const SparseIndexArray* LiteralBase::sparse_indices(
return piece(shape_index).sparse_indices();
}
SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) {
SparseIndexArray* MutableLiteralBase::sparse_indices(
const ShapeIndex& shape_index) {
return piece(shape_index).sparse_indices();
}
template <typename NativeT>
Status Literal::CopySliceFromInternal(
Status MutableLiteralBase::CopySliceFromInternal(
const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
tensorflow::gtl::ArraySlice<int64> dest_base,
tensorflow::gtl::ArraySlice<int64> copy_size) {
@ -225,7 +229,7 @@ Status Literal::CopySliceFromInternal(
// proper stride size at the matching dimension.
DimensionVector src_indexes(src_base.size(), 0);
DimensionVector dest_indexes(dest_base.size(), 0);
Literal::StrideConfig stride_config(src_literal.shape(), shape(),
MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(),
copy_size);
auto copy_proc = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
@ -253,7 +257,8 @@ Status Literal::CopySliceFromInternal(
return Status::OK();
}
Status Literal::CopyElementFrom(const LiteralSlice& src_literal,
Status MutableLiteralBase::CopyElementFrom(
const LiteralSlice& src_literal,
tensorflow::gtl::ArraySlice<int64> src_index,
tensorflow::gtl::ArraySlice<int64> dest_index) {
DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
@ -275,8 +280,8 @@ Status Literal::CopyElementFrom(const LiteralSlice& src_literal,
return Status::OK();
}
/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto(
const LiteralProto& proto) {
/* static */ StatusOr<std::unique_ptr<Literal>>
MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
if (!proto.has_shape()) {
return InvalidArgument("LiteralProto has no shape");
}
@ -405,7 +410,7 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
return Status::OK();
}
Status Literal::CopyFrom(const LiteralSlice& src_literal,
Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
const ShapeIndex& dest_shape_index,
const ShapeIndex& src_shape_index) {
const Shape& dest_subshape =
@ -482,7 +487,8 @@ Status Literal::MoveFrom(Literal&& src_literal,
return Status::OK();
}
Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
Status MutableLiteralBase::CopySliceFrom(
const LiteralSlice& src_literal,
tensorflow::gtl::ArraySlice<int64> src_base,
tensorflow::gtl::ArraySlice<int64> dest_base,
tensorflow::gtl::ArraySlice<int64> copy_size) {
@ -543,7 +549,7 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
shape().element_type());
}
void Literal::PopulateR1(const tensorflow::core::Bitmap& values) {
void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(ShapeUtil::Rank(shape()), 1);
CHECK_EQ(element_count(), values.bits());
@ -895,8 +901,8 @@ size_t LiteralBase::Hash() const {
return hash_value;
}
Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
int64 value) {
Status MutableLiteralBase::SetIntegralAsS64(
tensorflow::gtl::ArraySlice<int64> multi_index, int64 value) {
CHECK(LayoutUtil::IsDenseArray(shape()));
switch (shape().element_type()) {
case PRED:
@ -933,7 +939,7 @@ tensorflow::gtl::ArraySlice<int64> LiteralBase::GetSparseIndex(
return p.sparse_indices()->At(sparse_element_number);
}
void Literal::SortSparseElements(const ShapeIndex& shape_index) {
void MutableLiteralBase::SortSparseElements(const ShapeIndex& shape_index) {
piece(shape_index).SortSparseElements();
}
@ -1391,11 +1397,11 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
elements.push_back(std::move(*new_element));
}
auto converted = MakeUnique<Literal>();
*converted = Literal::MoveIntoTuple(&elements);
*converted = MutableLiteralBase::MoveIntoTuple(&elements);
return std::move(converted);
}
/* static */ Literal Literal::MoveIntoTuple(
/* static */ Literal MutableLiteralBase::MoveIntoTuple(
tensorflow::gtl::MutableArraySlice<Literal> elements) {
std::vector<Shape> element_shapes;
for (const Literal& element : elements) {
@ -1808,7 +1814,8 @@ Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
} // namespace
Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
// These conditions should have been checked in Literal::CreateFromProto.
// These conditions should have been checked in
// MutableLiteralBase::CreateFromProto.
TF_RET_CHECK(proto.has_shape());
TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape()));
@ -1900,7 +1907,7 @@ const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
return piece(shape_index).untyped_data();
}
void* Literal::untyped_data(const ShapeIndex& shape_index) {
void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) {
return piece(shape_index).untyped_data();
}
@ -1916,6 +1923,127 @@ string LiteralBase::GetR1U8AsString() const {
ShapeUtil::ElementsIn(shape()));
}
void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape,
Piece* src_piece,
Piece* dest_piece) {
DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape()))
<< "src_piece has shape: "
<< ShapeUtil::HumanString(src_piece->subshape())
<< "dest_piece has shape: "
<< ShapeUtil::HumanString(dest_piece->subshape());
if (ShapeUtil::IsTuple(shape)) {
for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
const Shape& subshape = shape.tuple_shapes(i);
auto child_piece = Piece();
child_piece.set_subshape(&subshape);
CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece);
dest_piece->emplace_back(std::move(child_piece));
}
} else if (ShapeUtil::IsArray(shape)) {
dest_piece->set_buffer(src_piece->buffer());
} else {
// If the shape is neither an array nor tuple, then it must be
// zero-sized. Otherwise, some memory needs to be allocated for it.
CHECK_EQ(dest_piece->size_bytes(), 0);
}
}
MutableLiteralBase::~MutableLiteralBase() {}
MutableBorrowingLiteral::MutableBorrowingLiteral(
const MutableBorrowingLiteral& literal)
: MutableLiteralBase() {
shape_ = MakeUnique<Shape>(literal.shape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
}
MutableBorrowingLiteral& MutableBorrowingLiteral::operator=(
const MutableBorrowingLiteral& literal) {
shape_ = MakeUnique<Shape>(literal.shape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
return *this;
}
MutableBorrowingLiteral::MutableBorrowingLiteral(
const MutableLiteralBase& literal)
: MutableLiteralBase() {
shape_ = MakeUnique<Shape>(literal.shape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
}
MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal)
: MutableLiteralBase() {
shape_ = MakeUnique<Shape>(literal->shape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_);
}
MutableBorrowingLiteral::MutableBorrowingLiteral(
MutableBorrowingLiteral literal, const ShapeIndex& view_root)
: MutableLiteralBase() {
shape_ = MakeUnique<Shape>(literal.piece(view_root).subshape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_);
}
MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr,
const Shape& shape)
: MutableLiteralBase() {
shape_ = MakeUnique<Shape>(shape);
CHECK(LayoutUtil::HasLayout(*shape_));
CHECK(!ShapeUtil::IsTuple(*shape_));
root_piece_ = new Piece();
root_piece_->set_buffer(const_cast<char*>(src_buf_ptr));
root_piece_->set_subshape(shape_.get());
}
MutableBorrowingLiteral::~MutableBorrowingLiteral() {
if (root_piece_ != nullptr) {
root_piece_->ForEachMutableSubpiece(
[&](const ShapeIndex& index, Piece* piece) {
if (piece->buffer() != nullptr) {
delete piece->sparse_indices();
}
});
delete root_piece_;
}
}
LiteralSlice::LiteralSlice(const LiteralBase& literal)
: LiteralBase(), root_piece_(&literal.root_piece()) {}
LiteralSlice::LiteralSlice(const LiteralBase& literal,
const ShapeIndex& view_root)
: LiteralBase(), root_piece_(&literal.piece(view_root)) {}
void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
CHECK(ShapeUtil::IsTuple(shape));
for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
@ -1932,13 +2060,6 @@ void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
}
}
LiteralSlice::LiteralSlice(const LiteralBase& literal)
: LiteralBase(), root_piece_(&literal.root_piece()) {}
LiteralSlice::LiteralSlice(const LiteralBase& literal,
const ShapeIndex& view_root)
: LiteralBase(), root_piece_(&literal.piece(view_root)) {}
BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
: LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
CHECK(ShapeUtil::IsArray(*shape_));

View File

@ -310,9 +310,10 @@ class LiteralBase {
// type of literal itself (0 for numeric types, and false for predicates).
//
// Note: It's an antipattern to use this method then immediately call
// Literal::Populate on the result (since that results in zero initialization,
// then reinitialization. Conside if a call to MakeUnique<Literal>(shape),
// followed by the call to Literal::Populate can be used instead.
// MutableLiteralBase::Populate on the result (since that results in zero
// initialization, then reinitialization. Conside if a call to
// MakeUnique<Literal>(shape), followed by the call to
// MutableLiteralBase::Populate can be used instead.
static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
protected:
@ -534,7 +535,7 @@ class LiteralBase {
virtual const Piece& root_piece() const = 0;
// LiteralSlice and Literal must access Pieces of other Literals.
friend class Literal;
friend class MutableLiteralBase;
friend class LiteralSlice;
friend class BorrowingLiteral;
@ -545,33 +546,10 @@ class LiteralBase {
tensorflow::gtl::ArraySlice<int64> start_indices) const;
};
// Class representing literal values in XLA.
//
// The underlying buffer and shape is always owned by this class.
class Literal : public LiteralBase {
// Abstract base class representing a mutable literal in XLA.
class MutableLiteralBase : public LiteralBase {
public:
Literal() : Literal(ShapeUtil::MakeNil()) {}
// Create a literal of the given shape. The literal is allocated sufficient
// memory to hold the shape. Memory is uninitialized.
explicit Literal(const Shape& shape);
virtual ~Literal();
// Literals are moveable, but not copyable. To copy a literal use
// Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
// of literals which can be expensive.
Literal(const Literal& other) = delete;
Literal& operator=(const Literal& other) = delete;
Literal(Literal&& other);
// 'allocate_arrays' indicates whether to allocate memory for the arrays in
// the shape. If false, buffer pointers inside of the Literal::Pieces are set
// to nullptr.
Literal(const Shape& shape, bool allocate_arrays);
Literal& operator=(Literal&& other);
// TODO(b/67651157): Remove this accessor. Literal users should not be able to
// mutate the shape as this can produce malformed Literals.
Shape* mutable_shape_do_not_use() { return shape_.get(); }
virtual ~MutableLiteralBase() = 0;
// Returns a MutableArraySlice view of the array for this literal for the
// given NativeT (e.g., float). CHECKs if the subshape of the literal at the
@ -587,6 +565,10 @@ class Literal : public LiteralBase {
// is not a sparse array.
SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
// TODO(b/67651157): Remove this accessor. Literal users should not be able to
// mutate the shape as this can produce malformed Literals.
Shape* mutable_shape_do_not_use() { return shape_.get(); }
// Returns a pointer to the underlying buffer holding the array at the given
// shape index. CHECKs if the subshape of the literal at the given ShapeIndex
// is not array.
@ -613,21 +595,6 @@ class Literal : public LiteralBase {
const ShapeIndex& dest_shape_index = {},
const ShapeIndex& src_shape_index = {});
// Returns a vector containing the tuple elements of this Literal as separate
// Literals. This Literal must be tuple-shaped and can be a nested tuple. The
// elements are moved into the new Literals; no data is copied. Upon return
// this Literal is set to a nil shape (empty tuple)
std::vector<Literal> DecomposeTuple();
// Similar to CopyFrom, but with move semantincs. The subshape of this literal
// rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
// (layouts and shapes must match), but need not be arrays. The memory
// allocated in this literal for the subshape at dest_shape_index is
// deallocated, and the respective buffers are replaced with those in
// src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
Status MoveFrom(Literal&& src_literal,
const ShapeIndex& dest_shape_index = {});
// Copies the values from src_literal, starting at src_base shape indexes,
// to this literal, starting at dest_base, where the copy size in each
// dimension is specified by copy_size.
@ -730,12 +697,7 @@ class Literal : public LiteralBase {
static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
const LiteralProto& proto);
private:
// Recursively sets the subshapes and buffers of all subpieces rooted at
// 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
// the shape.
void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
protected:
// Returns the piece at the given ShapeIndex.
Piece& piece(const ShapeIndex& shape_index) {
return const_cast<Piece&>(LiteralBase::piece(shape_index));
@ -783,12 +745,83 @@ class Literal : public LiteralBase {
template <typename NativeT, typename FnType>
Status PopulateInternal(const FnType& generator, bool parallel);
friend class LiteralBase;
friend class MutableBorrowingLiteral;
};
std::ostream& operator<<(std::ostream& out, const Literal& literal);
// The underlying buffer and shape is always owned by this class.
class Literal : public MutableLiteralBase {
public:
Literal() : Literal(ShapeUtil::MakeNil()) {}
// Create a literal of the given shape. The literal is allocated sufficient
// memory to hold the shape. Memory is uninitialized.
explicit Literal(const Shape& shape);
virtual ~Literal();
// Literals are moveable, but not copyable. To copy a literal use
// Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
// of literals which can be expensive.
Literal(const Literal& other) = delete;
Literal& operator=(const Literal& other) = delete;
Literal(Literal&& other);
// 'allocate_arrays' indicates whether to allocate memory for the arrays in
// the shape. If false, buffer pointers inside of the Literal::Pieces are set
// to nullptr.
Literal(const Shape& shape, bool allocate_arrays);
Literal& operator=(Literal&& other);
// Similar to CopyFrom, but with move semantincs. The subshape of this literal
// rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
// (layouts and shapes must match), but need not be arrays. The memory
// allocated in this literal for the subshape at dest_shape_index is
// deallocated, and the respective buffers are replaced with those in
// src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
virtual Status MoveFrom(Literal&& src_literal,
const ShapeIndex& dest_shape_index = {});
// Returns a vector containing the tuple elements of this Literal as separate
// Literals. This Literal must be tuple-shaped and can be a nested tuple. The
// elements are moved into the new Literals; no data is copied. Upon return
// this Literal is set to a nil shape (empty tuple)
std::vector<Literal> DecomposeTuple();
private:
// Deallocate the buffers held by this literal.
void DeallocateBuffers();
friend class LiteralBase;
// Recursively sets the subshapes and buffers of all subpieces rooted at
// 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
// the shape.
void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
};
// The underlying buffer is not owned by this class and is always owned by
// others. The shape is not owned by this class and not mutable.
class MutableBorrowingLiteral : public MutableLiteralBase {
public:
virtual ~MutableBorrowingLiteral();
MutableBorrowingLiteral() : MutableLiteralBase() {}
MutableBorrowingLiteral(const MutableBorrowingLiteral& literal);
MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal);
// Implicit conversion constructors.
MutableBorrowingLiteral(const MutableLiteralBase& literal);
MutableBorrowingLiteral(MutableLiteralBase* literal);
MutableBorrowingLiteral(MutableBorrowingLiteral literal,
const ShapeIndex& view_root);
MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
private:
// Recursively copies the subtree from the `src_piece` at the given child
// index to the `dest_piece`. For buffers only the pointers are copied, but
// not the content.
void CopyPieceSubtree(const Shape& shape, Piece* src_piece,
Piece* dest_piece);
};
std::ostream& operator<<(std::ostream& out, const Literal& literal);
// A read-only view of a Literal. A LiteralSlice contains pointers to shape and
// literal buffers always owned by others.
@ -831,9 +864,9 @@ class BorrowingLiteral : public LiteralBase {
const Piece& root_piece() const override { return root_piece_; };
Piece root_piece_;
// Shape of this literal. Stored as unique_ptr so such that the (default)
// move construction of this class would be trivially correct: the pointer to
// Shape root_piece_ stores will still point to the correct address.
// Shape of this literal. Stored as unique_ptr such that the (default) move
// construction of this class would be trivially correct: the pointer to Shape
// root_piece_ stores will still point to the correct address.
std::unique_ptr<Shape> shape_;
};
@ -886,7 +919,7 @@ tensorflow::gtl::ArraySlice<NativeT> LiteralBase::data(
}
template <typename NativeT>
tensorflow::gtl::MutableArraySlice<NativeT> Literal::data(
tensorflow::gtl::MutableArraySlice<NativeT> MutableLiteralBase::data(
const ShapeIndex& shape_index) {
return piece(shape_index).data<NativeT>();
}
@ -904,14 +937,15 @@ inline NativeT LiteralBase::Get(
}
template <typename NativeT>
inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
inline void MutableLiteralBase::Set(
tensorflow::gtl::ArraySlice<int64> multi_index,
const ShapeIndex& shape_index, NativeT value) {
return piece(shape_index).Set<NativeT>(multi_index, value);
}
template <typename NativeT>
inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
NativeT value) {
inline void MutableLiteralBase::Set(
tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value) {
return root_piece().Set<NativeT>(multi_index, value);
}
@ -929,7 +963,7 @@ NativeT LiteralBase::GetSparseElement(int64 sparse_element_number,
}
template <typename NativeT>
void Literal::AppendSparseElement(
void MutableLiteralBase::AppendSparseElement(
tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value,
const ShapeIndex& shape_index) {
Piece& p = piece(shape_index);
@ -959,7 +993,8 @@ void LiteralBase::EachCell(
}
template <typename NativeT>
inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) {
inline void MutableLiteralBase::PopulateR1(
tensorflow::gtl::ArraySlice<NativeT> values) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(ShapeUtil::Rank(shape()), 1);
CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
@ -971,7 +1006,7 @@ inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) {
}
template <typename NativeT>
void Literal::PopulateR2(
void MutableLiteralBase::PopulateR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(ShapeUtil::Rank(shape()), 2);
@ -996,7 +1031,7 @@ void Literal::PopulateR2(
}
template <typename NativeT>
void Literal::PopulateFromArray(const Array<NativeT>& values) {
void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(shape().element_type(),
primitive_util::NativeToPrimitiveType<NativeT>());
@ -1009,23 +1044,23 @@ void Literal::PopulateFromArray(const Array<NativeT>& values) {
}
template <typename NativeT>
void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
void MutableLiteralBase::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
PopulateFromArray(values);
}
template <typename NativeT>
void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
void MutableLiteralBase::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
PopulateFromArray(values);
}
template <typename NativeT>
void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
PopulateFromArray(values);
}
template <typename NativeT>
void Literal::PopulateSparse(SparseIndexArray indices,
tensorflow::gtl::ArraySlice<NativeT> values,
void MutableLiteralBase::PopulateSparse(
SparseIndexArray indices, tensorflow::gtl::ArraySlice<NativeT> values,
bool sort) {
CHECK(LayoutUtil::IsSparseArray(shape()));
int rank = ShapeUtil::Rank(shape());
@ -1049,7 +1084,8 @@ void Literal::PopulateSparse(SparseIndexArray indices,
}
template <typename NativeT, typename FnType>
Status Literal::PopulateInternal(const FnType& generator, bool parallel) {
Status MutableLiteralBase::PopulateInternal(const FnType& generator,
bool parallel) {
const Shape& this_shape = shape();
const int64 rank = ShapeUtil::Rank(this_shape);
TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
@ -1092,17 +1128,17 @@ Status Literal::PopulateInternal(const FnType& generator, bool parallel) {
return Status::OK();
}
template <typename NativeT, typename FnType>
Status Literal::Populate(const FnType& generator) {
Status MutableLiteralBase::Populate(const FnType& generator) {
return PopulateInternal<NativeT>(generator, /*parallel=*/false);
}
template <typename NativeT, typename FnType>
Status Literal::PopulateParallel(const FnType& generator) {
Status MutableLiteralBase::PopulateParallel(const FnType& generator) {
return PopulateInternal<NativeT>(generator, /*parallel=*/true);
}
template <typename NativeT>
void Literal::PopulateWithValue(NativeT value) {
void MutableLiteralBase::PopulateWithValue(NativeT value) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(shape().element_type(),
primitive_util::NativeToPrimitiveType<NativeT>());

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::strings::StrCat;

View File

@ -570,7 +570,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//third_party/eigen3",
@ -613,6 +613,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/core:lib",
"//tensorflow/core:ptr_util",
"//tensorflow/core:stream_executor_no_cuda",
],
alwayslink = 1,
@ -1384,6 +1385,18 @@ tf_cc_test(
],
)
cc_library(
name = "while_loop_analysis",
srcs = ["while_loop_analysis.cc"],
hdrs = ["while_loop_analysis.h"],
deps = [
":hlo",
":hlo_evaluator",
"//tensorflow/compiler/xla:literal",
"//tensorflow/core:lib",
],
)
cc_library(
name = "while_loop_simplifier",
srcs = ["while_loop_simplifier.cc"],
@ -1391,8 +1404,8 @@ cc_library(
deps = [
":call_inliner",
":hlo",
":hlo_evaluator",
":hlo_pass",
":while_loop_analysis",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",
],

View File

@ -1803,6 +1803,12 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
}
Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
// TODO(b/112040122): Most of those optimizations can be done for multi-output
// reduces.
if (ShapeUtil::IsTuple(reduce->shape())) {
return Status::OK();
}
auto arg = reduce->mutable_operand(0);
auto init_value = reduce->mutable_operand(1);
tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());

View File

@ -48,11 +48,6 @@ namespace xla {
// compuation.
using ObjectFileData = std::vector<char>;
// Contains the buffer sizes information needed to allocate buffers to execute
// an ahead-of-time computation. Entries which contain -1 designate a parameter
// which should be skipped over during allocation.
using BufferSizes = std::vector<int64>;
// Abstract superclass describing the result of an ahead-of-time compilation.
class AotCompilationResult {
public:

View File

@ -54,12 +54,24 @@ cc_library(
alwayslink = True, # Contains per-platform transfer manager registration
)
cc_library(
name = "buffer_info_util",
srcs = ["buffer_info_util.cc"],
hdrs = ["buffer_info_util.h"],
deps = [
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/core:lib",
],
)
cc_library(
name = "cpu_compiler",
srcs = ["cpu_compiler.cc"],
hdrs = ["cpu_compiler.h"],
deps = [
":compiler_functor",
":buffer_info_util",
":conv_canonicalization",
":cpu_copy_insertion",
":cpu_executable",
@ -73,6 +85,7 @@ cc_library(
":ir_emitter",
":parallel_task_assignment",
":simple_orc_jit",
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:status_macros",

View File

@ -0,0 +1,57 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
namespace xla {
namespace cpu {
using BufferInfo = ::tensorflow::cpu_function_runtime::BufferInfo;
std::vector<BufferInfo> CreateBufferInfosFromBufferAssignment(
const BufferAssignment& buffer_assignment) {
std::vector<BufferInfo> buffer_infos;
for (const BufferAllocation& allocation : buffer_assignment.Allocations()) {
if (allocation.is_thread_local()) {
buffer_infos.push_back(BufferInfo::MakeOnStackBuffer(allocation.size()));
} else if (allocation.is_constant()) {
buffer_infos.push_back(BufferInfo::MakeConstant(allocation.size()));
} else if (allocation.is_entry_computation_parameter()) {
buffer_infos.push_back(BufferInfo::MakeEntryParameter(
/*size=*/allocation.size(),
/*param_number=*/allocation.parameter_number()));
} else {
buffer_infos.push_back(BufferInfo::MakeTempBuffer(allocation.size()));
}
}
return buffer_infos;
}
std::vector<int32> CreateArgIndexTableFromBufferInfos(
tensorflow::gtl::ArraySlice<BufferInfo> buffer_infos) {
std::vector<int32> result;
for (int64 i = 0; i < buffer_infos.size(); i++) {
if (buffer_infos[i].is_entry_parameter()) {
if (buffer_infos[i].entry_parameter_number() >= result.size()) {
result.resize(buffer_infos[i].entry_parameter_number() + 1);
}
result[buffer_infos[i].entry_parameter_number()] = i;
}
}
return result;
}
} // namespace cpu
} // namespace xla

View File

@ -0,0 +1,42 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace cpu {
// Creates and returns a list of BufferInfo instances containing relevant
// information from `buffer_assignment`.
std::vector<::tensorflow::cpu_function_runtime::BufferInfo>
CreateBufferInfosFromBufferAssignment(
const BufferAssignment& buffer_assignment);
// Creates and returns a table containing the mapping from entry computation
// parameters to buffer allocation indices.
//
// If this function returns V then entry parameter i has buffer allocation index
// V[i].
std::vector<int32> CreateArgIndexTableFromBufferInfos(
tensorflow::gtl::ArraySlice<::tensorflow::cpu_function_runtime::BufferInfo>
buffer_infos);
} // namespace cpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_

View File

@ -50,6 +50,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
@ -103,6 +104,7 @@ limitations under the License.
namespace xla {
namespace cpu {
using BufferInfo = ::tensorflow::cpu_function_runtime::BufferInfo;
CpuAotCompilationOptions::CpuAotCompilationOptions(
string triple, string cpu_name, string features, string entry_point_name,
@ -120,11 +122,11 @@ se::Platform::Id CpuAotCompilationOptions::PlatformId() const {
}
CpuAotCompilationResult::CpuAotCompilationResult(
ObjectFileData object_file_data, BufferSizes buffer_sizes,
ObjectFileData object_file_data, std::vector<BufferInfo> buffer_infos,
int64 result_buffer_index,
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data)
: object_file_data_(std::move(object_file_data)),
buffer_sizes_(std::move(buffer_sizes)),
buffer_infos_(std::move(buffer_infos)),
result_buffer_index_(result_buffer_index),
hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {}
@ -838,39 +840,14 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
ObjectFileData object_file_data(object_file->getBufferStart(),
object_file->getBufferEnd());
BufferSizes buffer_sizes;
for (const BufferAllocation& allocation : assignment->Allocations()) {
// Callers don't need to allocate anything for thread-local temporary
// buffers. They are lowered to allocas.
if (allocation.is_thread_local()) {
buffer_sizes.push_back(-1);
continue;
}
// Callers don't need to allocate anything for constant buffers. They are
// lowered to globals.
if (allocation.is_constant()) {
buffer_sizes.push_back(-1);
continue;
}
// Callers don't need to allocate anything for entry computation buffers,
// but they do need to stash the pointer to the entry computation buffer
// in the temp buffer table. See the comment on
// XlaCompiledCpuFunction::StaticData::temp_sizes.
if (allocation.is_entry_computation_parameter()) {
buffer_sizes.push_back(-allocation.parameter_number() - 2);
continue;
}
buffer_sizes.push_back(allocation.size());
}
std::vector<BufferInfo> buffer_infos =
CreateBufferInfosFromBufferAssignment(*assignment);
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
assignment->GetUniqueTopLevelOutputSlice());
results.emplace_back(MakeUnique<CpuAotCompilationResult>(
std::move(object_file_data), std::move(buffer_sizes),
std::move(object_file_data), std::move(buffer_infos),
result_slice.index(), std::move(hlo_profile_printer_data)));
}

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
@ -78,7 +79,8 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
class CpuAotCompilationResult : public AotCompilationResult {
public:
CpuAotCompilationResult(
ObjectFileData object_file_data, BufferSizes buffer_sizes,
ObjectFileData object_file_data,
std::vector<::tensorflow::cpu_function_runtime::BufferInfo> buffer_infos,
int64 result_buffer_index,
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data);
~CpuAotCompilationResult();
@ -88,17 +90,20 @@ class CpuAotCompilationResult : public AotCompilationResult {
}
const ObjectFileData& object_file_data() const { return object_file_data_; }
const BufferSizes& buffer_sizes() const { return buffer_sizes_; }
const std::vector<::tensorflow::cpu_function_runtime::BufferInfo>&
buffer_infos() const {
return buffer_infos_;
}
int64 result_buffer_index() const { return result_buffer_index_; }
private:
// Contains the compiled computation: an object file.
const ObjectFileData object_file_data_;
// The list of buffer sizes which should be allocated in order to execute the
// compiled computation. These buffers are used for temporary buffers used
// ephemerally during computation as well as the output result.
const BufferSizes buffer_sizes_;
// A list of BufferInfo objects describing the buffers used by the XLA
// computation.
const std::vector<::tensorflow::cpu_function_runtime::BufferInfo>
buffer_infos_;
// Contains which buffer index into |buffer_sizes| was designated to the
// result of the computation. This buffer should be passed into the output

View File

@ -173,7 +173,7 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor,
Status CpuTransferManager::TransferLiteralFromOutfeed(
se::StreamExecutor* executor, const Shape& literal_shape,
Literal* literal) {
MutableBorrowingLiteral literal) {
if (!ShapeUtil::IsTuple(literal_shape)) {
int64 size = GetByteSizeRequirement(literal_shape);
// Note: OSS build didn't like implicit conversion from
@ -181,18 +181,16 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
tensorflow::gtl::ArraySlice<int64> dimensions(
tensorflow::bit_cast<const int64*>(literal_shape.dimensions().data()),
literal_shape.dimensions().size());
*literal = std::move(*LiteralUtil::CreateFromDimensions(
literal_shape.element_type(), dimensions));
TF_ASSIGN_OR_RETURN(Shape received_shape,
TransferArrayBufferFromOutfeed(
executor, literal->untyped_data(), size));
TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal->shape()))
TF_ASSIGN_OR_RETURN(
Shape received_shape,
TransferArrayBufferFromOutfeed(executor, literal.untyped_data(), size));
TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal.shape()))
<< "Shape received from outfeed "
<< ShapeUtil::HumanString(received_shape)
<< " did not match the shape that was requested for outfeed: "
<< ShapeUtil::HumanString(literal_shape);
TF_RET_CHECK(size == GetByteSizeRequirement(received_shape));
*literal->mutable_shape_do_not_use() = received_shape;
*literal.mutable_shape_do_not_use() = received_shape;
return Status::OK();
}
@ -201,22 +199,12 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
"Nested tuple outfeeds are not yet implemented on CPU.");
}
std::vector<std::unique_ptr<Literal>> elements;
std::vector<std::pair<void*, int64>> buffer_data;
for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) {
const Shape& tuple_element_shape =
ShapeUtil::GetTupleElementShape(literal_shape, i);
// Note: OSS build didn't like implicit conversion from
// literal_shape.dimensions() to the array slice on 2017-07-10.
tensorflow::gtl::ArraySlice<int64> dimensions(
tensorflow::bit_cast<const int64*>(
tuple_element_shape.dimensions().data()),
tuple_element_shape.dimensions().size());
auto empty = LiteralUtil::CreateFromDimensions(
tuple_element_shape.element_type(), dimensions);
int64 size = GetByteSizeRequirement(tuple_element_shape);
buffer_data.push_back({empty->untyped_data(), size});
elements.push_back(std::move(empty));
buffer_data.push_back({literal.untyped_data({i}), size});
}
TF_ASSIGN_OR_RETURN(Shape received_shape,
@ -230,11 +218,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
TF_RET_CHECK(GetByteSizeRequirement(literal_shape) ==
GetByteSizeRequirement(received_shape));
for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) {
*elements[i]->mutable_shape_do_not_use() = received_shape.tuple_shapes(i);
}
*literal = std::move(*LiteralUtil::MakeTupleOwned(std::move(elements)));
TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape));
TF_RET_CHECK(ShapeUtil::Equal(literal.shape(), literal_shape));
return Status::OK();
}

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h"
#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
@ -41,7 +42,7 @@ class CpuTransferManager : public GenericTransferManager {
const LiteralSlice& literal) override;
Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
const Shape& literal_shape,
Literal* literal) override;
MutableBorrowingLiteral literal) override;
private:
Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size,

View File

@ -30,47 +30,6 @@ limitations under the License.
namespace xla {
namespace cpu {
StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const {
switch (op->opcode()) {
case HloOpcode::kTanh: {
PrimitiveType element_type = op->shape().element_type();
bool cast_result_to_fp16 = false;
string function_name;
switch (element_type) {
case F16:
cast_result_to_fp16 = true;
operand_value = b_->CreateFPCast(operand_value, b_->getFloatTy());
TF_FALLTHROUGH_INTENDED;
case F32:
function_name = "tanhf";
break;
case F64:
function_name = "tanh";
break;
default:
return Unimplemented("tanh");
}
// Create a function declaration.
llvm::Function* function =
llvm::cast<llvm::Function>(module_->getOrInsertFunction(
llvm_ir::AsStringRef(function_name), operand_value->getType(),
operand_value->getType()));
function->setCallingConv(llvm::CallingConv::C);
function->setDoesNotThrow();
function->setDoesNotAccessMemory();
// Create an instruction to call the function.
llvm::Value* result = b_->CreateCall(function, operand_value);
if (cast_result_to_fp16) {
result = b_->CreateFPCast(result, b_->getHalfTy());
}
return result;
}
default:
return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value);
}
}
StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const {
string function_name;
@ -106,6 +65,39 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
return result;
}
StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(
PrimitiveType prim_type, llvm::Value* value) const {
bool cast_result_to_fp16 = false;
string function_name;
switch (prim_type) {
case F16:
cast_result_to_fp16 = true;
value = b_->CreateFPCast(value, b_->getFloatTy());
TF_FALLTHROUGH_INTENDED;
case F32:
function_name = "tanhf";
break;
case F64:
function_name = "tanh";
break;
default:
return Unimplemented("tanh");
}
// Create a function declaration.
llvm::Function* function = llvm::cast<llvm::Function>(
module_->getOrInsertFunction(llvm_ir::AsStringRef(function_name),
value->getType(), value->getType()));
function->setCallingConv(llvm::CallingConv::C);
function->setDoesNotThrow();
function->setDoesNotAccessMemory();
// Create an instruction to call the function.
llvm::Value* result = b_->CreateCall(function, value);
if (cast_result_to_fp16) {
result = b_->CreateFPCast(result, b_->getHalfTy());
}
return result;
}
llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator) const {

View File

@ -39,10 +39,10 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
const HloToElementGeneratorMap& operand_to_generator) const override;
protected:
StatusOr<llvm::Value*> EmitFloatUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const override;
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
llvm::Value* rhs) const override;
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
llvm::Value* value) const override;
IrEmitter* ir_emitter_;
};

View File

@ -1756,6 +1756,10 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
}
Status IrEmitter::HandleReduce(HloInstruction* reduce) {
// TODO(b/112040122): Support variadic reduce.
if (!ShapeUtil::IsArray(reduce->shape())) {
return Unimplemented("Variadic reduce is not supported on CPU");
}
auto arg = reduce->mutable_operand(0);
auto init_value = reduce->mutable_operand(1);
gtl::ArraySlice<int64> dimensions(reduce->dimensions());

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
#if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY)
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
#include "third_party/intel_mkl_ml/include/mkl_cblas.h"
#include "third_party/intel_mkl_ml/include/mkl_service.h"

View File

@ -106,6 +106,7 @@ class DfsHloVisitorBase {
virtual Status HandleConvolution(HloInstructionPtr hlo) = 0;
virtual Status HandleFft(HloInstructionPtr fft) = 0;
virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0;
virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0;
virtual Status HandleCompare(HloInstructionPtr hlo) {
return HandleElementwiseBinary(hlo);
}

View File

@ -94,6 +94,9 @@ class DfsHloVisitorWithDefaultBase
Status HandleCrossReplicaSum(HloInstructionPtr crs) override {
return DefaultAction(crs);
}
Status HandleAllToAll(HloInstructionPtr crs) override {
return DefaultAction(crs);
}
Status HandleRng(HloInstructionPtr random) override {
return DefaultAction(random);
}

View File

@ -431,6 +431,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
return EmitCos(op->shape().element_type(), operand_value);
case HloOpcode::kSin:
return EmitSin(op->shape().element_type(), operand_value);
case HloOpcode::kTanh:
return EmitTanh(op->shape().element_type(), operand_value);
case HloOpcode::kFloor:
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor,
{operand_value},
@ -1060,6 +1062,11 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
return Unimplemented("atan2");
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
llvm::Value* value) const {
return Unimplemented("tanh");
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
const HloInstruction* hlo, llvm::Value* x) const {
if (hlo->operand(0)->shape().element_type() != F32) {
@ -1239,13 +1246,23 @@ StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution(
// Convert raw integer to float in range [0, 1) if the element is a float.
llvm::Value* elem_value = raw_value;
if (elem_ir_ty->isFloatingPointTy()) {
elem_value = b_->CreateUIToFP(elem_value, elem_ir_ty);
unsigned raw_value_size_in_bits = raw_value_ty->getPrimitiveSizeInBits();
CHECK(raw_value_size_in_bits == 32 || raw_value_size_in_bits == 64);
// Perform the division using the float type with the same number of bits
// as the raw value to avoid overflow.
if (raw_value_size_in_bits == 32) {
elem_value = b_->CreateUIToFP(elem_value, b_->getFloatTy());
elem_value = b_->CreateFDiv(
elem_value,
llvm::ConstantFP::get(elem_ir_ty,
raw_value_size_in_bits == 64 ? 0x1p64 : 0x1p32));
elem_value, llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32)));
} else {
elem_value = b_->CreateUIToFP(elem_value, b_->getDoubleTy());
elem_value = b_->CreateFDiv(
elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64)));
}
if (elem_ir_ty != elem_value->getType()) {
elem_value = b_->CreateFPTrunc(elem_value, elem_ir_ty);
}
}
// Convert the value for the requested distribution.
@ -1302,6 +1319,7 @@ int32 GetNumberOfElementsPerPhiloxRngSample(PrimitiveType elem_prim_ty) {
case F16:
return 4;
case U64:
case S64:
case F64:
return 2;
default:

View File

@ -122,6 +122,9 @@ class ElementalIrEmitter {
llvm::Value* lhs,
llvm::Value* rhs) const;
virtual StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
llvm::Value* value) const;
virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo,
llvm::Value* x) const;

View File

@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -60,17 +59,19 @@ Status GenericTransferManager::WriteSingleTupleIndexTable(
void GenericTransferManager::TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer,
std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) {
MutableBorrowingLiteral literal, std::function<void(Status)> done) {
Status status = stream->BlockHostUntilDone();
if (!status.ok()) {
return done(status);
}
done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer));
done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer,
literal));
}
StatusOr<std::unique_ptr<Literal>>
GenericTransferManager::TransferLiteralFromDeviceInternal(
se::StreamExecutor* executor, const ShapedBuffer& device_buffer) {
Status GenericTransferManager::TransferLiteralFromDeviceInternal(
se::StreamExecutor* executor, const ShapedBuffer& device_buffer,
MutableBorrowingLiteral literal) {
VLOG(2) << "transferring literal from device ordinal "
<< executor->device_ordinal() << "; device buffer: " << device_buffer;
TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal());
@ -80,9 +81,6 @@ GenericTransferManager::TransferLiteralFromDeviceInternal(
TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(),
device_buffer.on_host_shape()));
std::unique_ptr<Literal> literal =
Literal::CreateFromShape(device_buffer.on_host_shape());
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
device_buffer.on_host_shape(),
[&](const Shape& subshape, const ShapeIndex& index) -> Status {
@ -91,12 +89,12 @@ GenericTransferManager::TransferLiteralFromDeviceInternal(
/*source=*/device_buffer.buffer(index),
/*size=*/GetByteSizeRequirement(subshape),
/*destination=*/
literal->untyped_data(index)));
literal.untyped_data(index)));
}
return Status::OK();
}));
return std::move(literal);
return Status::OK();
}
Status GenericTransferManager::TransferLiteralToDeviceAsync(
@ -160,7 +158,7 @@ Status GenericTransferManager::TransferLiteralToInfeed(
Status GenericTransferManager::TransferLiteralFromOutfeed(
se::StreamExecutor* executor, const Shape& literal_shape,
Literal* literal) {
MutableBorrowingLiteral literal) {
return Unimplemented("Generic transfer from Outfeed");
}

View File

@ -19,7 +19,6 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@ -41,9 +40,10 @@ class GenericTransferManager : public TransferManager {
se::Platform::Id PlatformId() const override;
void TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer,
std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) override;
void TransferLiteralFromDevice(se::Stream* stream,
const ShapedBuffer& device_buffer,
MutableBorrowingLiteral literal,
std::function<void(Status)> done) override;
Status TransferLiteralToDeviceAsync(
se::Stream* stream, const LiteralSlice& literal,
@ -53,7 +53,7 @@ class GenericTransferManager : public TransferManager {
const LiteralSlice& literal) override;
Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
const Shape& literal_shape,
Literal* literal) override;
MutableBorrowingLiteral literal) override;
Status ResetDevices(
tensorflow::gtl::ArraySlice<se::StreamExecutor*> executors) override;
@ -67,8 +67,9 @@ class GenericTransferManager : public TransferManager {
const Shape& shape, se::DeviceMemoryBase* region) override;
private:
StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDeviceInternal(
se::StreamExecutor* executor, const ShapedBuffer& device_buffer);
Status TransferLiteralFromDeviceInternal(se::StreamExecutor* executor,
const ShapedBuffer& device_buffer,
MutableBorrowingLiteral literal);
// The platform this transfer manager targets.
const se::Platform::Id platform_id_;

View File

@ -153,7 +153,6 @@ cc_library(
":ir_emission_utils",
":parallel_loop_emitter",
":partition_assignment",
":while_transformer",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@ -166,6 +165,7 @@ cc_library(
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/compiler/xla/service:while_loop_analysis",
"//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
"//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
@ -655,7 +655,6 @@ cc_library(
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/service:tuple_simplifier",
"//tensorflow/compiler/xla/service:while_loop_constant_sinking",
"//tensorflow/compiler/xla/service:while_loop_invariant_code_motion",
"//tensorflow/compiler/xla/service:while_loop_simplifier",
"//tensorflow/compiler/xla/service:zero_sized_hlo_elimination",
"//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter",
@ -788,32 +787,17 @@ tf_cc_test(
],
)
cc_library(
name = "while_transformer",
srcs = ["while_transformer.cc"],
hdrs = ["while_transformer.h"],
deps = [
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "while_transformer_test",
srcs = ["while_transformer_test.cc"],
deps = [
":instruction_fusion",
":while_transformer",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/service:copy_insertion",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:while_loop_analysis",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
namespace xla {
namespace gpu {
@ -137,6 +138,28 @@ string NumBytesToString(int64 bytes) {
tensorflow::strings::HumanReadableNumBytes(bytes), " (", bytes, "B)");
}
// Acquires a process-global lock on the device pointed to by the given
// StreamExecutor.
//
// This is used to prevent other XLA instances from trying to autotune on this
// device while we're using it.
tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
// se::Platform*s are global singletons guaranteed to live forever.
static auto* mutexes =
new std::map<std::pair<const se::Platform*, /*device_ordinal*/ int64>,
tensorflow::mutex>();
tensorflow::mutex_lock global_lock(mu);
auto it = mutexes
->emplace(std::piecewise_construct,
std::make_tuple(stream_exec->platform(),
stream_exec->device_ordinal()),
std::make_tuple())
.first;
return tensorflow::mutex_lock{it->second};
}
} // anonymous namespace
// We could have caching here so that we don't redo this work for two identical
@ -155,6 +178,13 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) {
// Don't run this function concurrently on the same GPU.
//
// This is a bit of a hack and doesn't protect us against arbitrary concurrent
// use of a GPU, but it's sufficient to let us compile two HLO modules
// concurrently and then run them sequentially.
tensorflow::mutex_lock lock = LockGpu(stream_exec_);
// Create a stream for us to do our work on.
se::Stream stream{stream_exec_};
stream.Init();

View File

@ -272,27 +272,18 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(
prim_type);
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const {
PrimitiveType input_type = op->operand(0)->shape().element_type();
PrimitiveType output_type = op->shape().element_type();
switch (op->opcode()) {
case HloOpcode::kTanh:
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(
PrimitiveType prim_type, llvm::Value* value) const {
// If we don't care much about precision, emit a fast approximation of
// tanh.
if (hlo_module_config_.debug_options().xla_enable_fast_math()) {
// Upcast F16 to F32 if necessary.
llvm::Type* type =
input_type == F16 ? b_->getFloatTy() : operand_value->getType();
llvm::Value* input = b_->CreateFPCast(operand_value, type);
llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
llvm::Value* input = b_->CreateFPCast(value, type);
llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
return b_->CreateFPCast(fast_tanh, operand_value->getType());
}
return EmitLibdeviceMathCall("__nv_tanh", {operand_value}, {input_type},
output_type);
default:
return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value);
return b_->CreateFPCast(fast_tanh, value->getType());
}
return EmitLibdeviceMathCall("__nv_tanh", {value}, {prim_type}, prim_type);
}
llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
@ -445,6 +436,8 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
return b_->CreateLoad(accum_ptr);
};
case HloOpcode::kReduce:
// TODO(b/112040122): This should be supported.
CHECK_EQ(hlo->operand_count(), 2) << "Did not expect variadic reduce";
return [=, &operand_to_generator](
const IrArray::Index& output_index) -> StatusOr<llvm::Value*> {
const HloInstruction* operand = hlo->operand(0);

View File

@ -51,9 +51,6 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
const HloToElementGeneratorMap& operand_to_generator) const override;
protected:
StatusOr<llvm::Value*> EmitFloatUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const override;
StatusOr<llvm::Value*> EmitFloatBinaryOp(
const HloInstruction* op, llvm::Value* lhs_value,
llvm::Value* rhs_value) const override;
@ -85,6 +82,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
llvm::Value* rhs) const override;
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
llvm::Value* value) const override;
llvm::Value* EmitThreadId() const override;
private:

View File

@ -52,12 +52,12 @@ class GemmThunk : public Thunk {
se::Stream* stream,
HloExecutionProfiler* profiler) override;
// Returns true if we'll perform autotuning if run on the given stream. If
// so, we want the GPU to be quiescent during autotuning, so as not to
// introduce noise in our results.
bool ShouldHaltAllActivityBeforeRunning(se::Stream* stream) override {
return autotune_results_.count(
stream->parent()->GetDeviceDescription().name()) != 0;
bool WillAutotuneKernel(se::Stream* stream) override {
// We will autotune this kernel if we don't already have a autotune result
// for the stream device.
return autotune_results_.find(
stream->parent()->GetDeviceDescription().name()) ==
autotune_results_.end();
}
private:
@ -75,6 +75,8 @@ class GemmThunk : public Thunk {
// results. The map's value is the best algorithm we've found for this thunk
// on this device, or an error if none of the algorithms worked and we should
// use the regular gemm without an algorithm.
//
// TODO(b/112415150): Make this thread safe.
std::unordered_map<string, StatusOr<se::blas::AlgorithmType>>
autotune_results_;
};

View File

@ -131,9 +131,10 @@ Status GpuExecutable::ExecuteThunks(
stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get());
}
// If this thunk requests it, wait for all currently-executing thunks to
// finish. This is useful e.g. if the thunk is about to perform autotuning.
if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) {
// If this thunk is about to autotune then wait for all currently executing
// thunks to finish. This reduces noise and thus the probability of
// choosing a suboptimal algorithm.
if (thunk->WillAutotuneKernel(stream)) {
TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone());
}

Some files were not shown because too many files have changed in this diff Show More