Multiple blocks of code in StreamExecutor can be simplified using the
TF_ASSIGN_OR_RETURN macro; the StatusOr class it relies on is already
defined in StreamExecutor.

PiperOrigin-RevId: 249491049
This commit is contained in:
George Karpenkov 2019-05-22 11:52:57 -07:00 committed by TensorFlower Gardener
parent 02cd8bdcfa
commit 259b05ecad
4 changed files with 36 additions and 38 deletions

View File

@@ -187,28 +187,4 @@ class StatusAdaptorForMacros {
.with_log_stack_trace() \
.add_ret_check_failure(#condition)
#define TF_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \
TF_ASSERT_OK_AND_ASSIGN_IMPL( \
TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \
rexpr);
#define TF_ASSERT_OK_AND_ASSIGN_IMPL(statusor, lhs, rexpr) \
auto statusor = (rexpr); \
ASSERT_TRUE(statusor.status().ok()) << statusor.status(); \
lhs = std::move(statusor.ValueOrDie())
#define TF_STATUS_MACROS_CONCAT_NAME(x, y) TF_STATUS_MACROS_CONCAT_IMPL(x, y)
#define TF_STATUS_MACROS_CONCAT_IMPL(x, y) x##y
#define TF_ASSIGN_OR_RETURN(lhs, rexpr) \
TF_ASSIGN_OR_RETURN_IMPL( \
TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, rexpr)
#define TF_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \
auto statusor = (rexpr); \
if (TF_PREDICT_FALSE(!statusor.ok())) { \
return statusor.status(); \
} \
lhs = std::move(statusor.ValueOrDie())
#endif // TENSORFLOW_COMPILER_XLA_STATUS_MACROS_H_

View File

@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/platform/subprocess.h"
#include "tensorflow/stream_executor/cuda/cuda_driver.h"
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace stream_executor {
namespace cuda {
@@ -120,11 +121,9 @@ port::StatusOr<absl::Span<const uint8>> CompilePtxOrGetCached(
compilation_options.ToTuple()};
auto it = ptx_cache.find(cache_key);
if (it == ptx_cache.end()) {
auto compiled_or = CompilePtx(device_ordinal, ptx, compilation_options);
TF_RETURN_IF_ERROR(compiled_or.status());
std::vector<uint8> compiled = std::move(compiled_or.ValueOrDie());
it =
ptx_cache.emplace(cache_key, std::move(compiled_or.ValueOrDie())).first;
TF_ASSIGN_OR_RETURN(std::vector<uint8> compiled,
CompilePtx(device_ordinal, ptx, compilation_options));
it = ptx_cache.emplace(cache_key, std::move(compiled)).first;
}
CHECK(it != ptx_cache.end());

View File

@@ -306,6 +306,31 @@ void StatusOr<T>::IgnoreError() const {
}
} // namespace port
#define TF_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \
TF_ASSERT_OK_AND_ASSIGN_IMPL( \
TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \
rexpr);
#define TF_ASSERT_OK_AND_ASSIGN_IMPL(statusor, lhs, rexpr) \
auto statusor = (rexpr); \
ASSERT_TRUE(statusor.status().ok()) << statusor.status(); \
lhs = std::move(statusor.ValueOrDie())
#define TF_STATUS_MACROS_CONCAT_NAME(x, y) TF_STATUS_MACROS_CONCAT_IMPL(x, y)
#define TF_STATUS_MACROS_CONCAT_IMPL(x, y) x##y
#define TF_ASSIGN_OR_RETURN(lhs, rexpr) \
TF_ASSIGN_OR_RETURN_IMPL( \
TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, rexpr)
#define TF_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \
auto statusor = (rexpr); \
if (TF_PREDICT_FALSE(!statusor.ok())) { \
return statusor.status(); \
} \
lhs = std::move(statusor.ValueOrDie())
} // namespace stream_executor
#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_H_

View File

@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rng.h"
@@ -872,11 +873,9 @@ StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
port::StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate(
int device_ordinal, uint64 size, bool retry_on_failure) {
port::StatusOr<StreamExecutor *> stream_executor_or =
GetStreamExecutor(device_ordinal);
TF_RETURN_IF_ERROR(stream_executor_or.status());
DeviceMemoryBase result =
stream_executor_or.ValueOrDie()->AllocateArray<uint8>(size);
TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
GetStreamExecutor(device_ordinal));
DeviceMemoryBase result = executor->AllocateArray<uint8>(size);
if (size > 0 && result == nullptr) {
return tensorflow::errors::ResourceExhausted(absl::StrFormat(
"Failed to allocate request for %s (%uB) on device ordinal %d",
@@ -893,12 +892,11 @@ port::StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate(
port::Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal,
DeviceMemoryBase mem) {
if (!mem.is_null()) {
port::StatusOr<StreamExecutor *> stream_executor_or =
GetStreamExecutor(device_ordinal);
TF_RETURN_IF_ERROR(stream_executor_or.status());
TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
GetStreamExecutor(device_ordinal));
VLOG(3) << absl::StreamFormat("Freeing %p on device ordinal %d",
mem.opaque(), device_ordinal);
stream_executor_or.ValueOrDie()->Deallocate(&mem);
executor->Deallocate(&mem);
}
return port::Status::OK();
}