Cleanup: Ran clang-format on files in tensorflow/core/.../*.{cc,h}.

PiperOrigin-RevId: 183848459
A. Unique TensorFlower 2018-01-30 10:05:04 -08:00 committed by TensorFlower Gardener
parent 88eb6c61ef
commit 7149a2e2e2
150 changed files with 1098 additions and 1060 deletions


@@ -13,11 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/common_runtime/optimization_registry.h"
 #include "tensorflow/core/graph/node_builder.h"
 namespace tensorflow {
 namespace {
@@ -44,7 +42,6 @@ Tensor make_zeros(const DataType& dtype, const TensorShapeProto& shape) {
 // third-party libraries aren't currently supported.
 class AccumulateNV2RemovePass : public GraphOptimizationPass {
  public:
   Status Run(const GraphOptimizationPassOptions& options) override {
     // TODO(freiss.oss@gmail.com): Substantial shared code with
     // ParallelConcatRemovePass::Run(). Consider refactoring if someone makes
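The interface these passes implement is small: a subclass overrides Run() and receives a GraphOptimizationPassOptions carrying either a whole graph or per-device partition graphs, depending on the phase it registered for. A minimal sketch, not part of this commit (the pass name and the registration grouping/phase are assumptions chosen for illustration):

    #include "tensorflow/core/common_runtime/optimization_registry.h"

    namespace tensorflow {

    class NoOpPass : public GraphOptimizationPass {
     public:
      Status Run(const GraphOptimizationPassOptions& options) override {
        // Either options.graph or options.partition_graphs is populated;
        // MklLayoutRewritePass::Run() further below uses the same guard.
        if (options.graph == nullptr && options.partition_graphs == nullptr) {
          return Status::OK();
        }
        // Inspect or rewrite the graph(s) here.
        return Status::OK();
      }
    };

    // Registration macro from optimization_registry.h; the grouping and
    // phase values here are illustrative assumptions.
    REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 0,
                          NoOpPass);

    }  // namespace tensorflow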


@@ -127,9 +127,9 @@ class BFCAllocator : public VisitableAllocator {
     string DebugString(BFCAllocator* a,
                        bool recurse) NO_THREAD_SAFETY_ANALYSIS {
       string dbg;
-      strings::StrAppend(&dbg, " Size: ", strings::HumanReadableNumBytes(size),
-                         " | Requested Size: ",
-                         strings::HumanReadableNumBytes(requested_size),
-                         " | in_use: ", in_use());
+      strings::StrAppend(
+          &dbg, " Size: ", strings::HumanReadableNumBytes(size),
+          " | Requested Size: ", strings::HumanReadableNumBytes(requested_size),
+          " | in_use: ", in_use());
       if (recurse && prev != BFCAllocator::kInvalidChunkHandle) {
         Chunk* p = a->ChunkFromHandle(prev);


@@ -88,7 +88,9 @@ TEST_F(DeviceSetTest, PrioritizedDeviceTypeList) {
   // D3 is prioritized below D1.
   AddDevice("d3", "/job:a/replica:0/task:0/device:d3:0");
   EXPECT_EQ((std::vector<DeviceType>{
-                DeviceType("d2"), DeviceType("d1"), DeviceType("d3"),
+                DeviceType("d2"),
+                DeviceType("d1"),
+                DeviceType("d3"),
             }),
             types());
 }


@@ -61,7 +61,6 @@ limitations under the License.
 #include "tensorflow/core/util/device_name_utils.h"
 #include "tensorflow/core/util/env_var.h"
 namespace tensorflow {
 namespace {
@@ -472,9 +471,9 @@ Status DirectSession::Run(const RunOptions& run_options,
   Executor::Args args;
   args.step_id = step_id_counter_.fetch_add(1);
-  TF_RETURN_IF_ERROR(
-      GetOrCreateExecutors(input_tensor_names, output_names, target_nodes,
-                           &executors_and_keys, &run_state_args));
+  TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
+                                          target_nodes, &executors_and_keys,
+                                          &run_state_args));
   const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);
   std::unique_ptr<DebuggerStateInterface> debugger_state;


@@ -436,10 +436,7 @@ TEST(DirectSessionTest, FetchMultipleTimes) {
   }
 }
-REGISTER_OP("Darth")
-    .Input("x: float")
-    .Output("y: float")
-    .Doc(R"doc(
+REGISTER_OP("Darth").Input("x: float").Output("y: float").Doc(R"doc(
 Darth promises one return value.
 x: float
@@ -972,10 +969,9 @@ static void TestSessionInterOpThreadsImpl(bool use_function_lib,
   std::atomic<int32> num_done(0);
   // Runs session to compute <node>:0 using inter_op thread pool <pool>.
-  auto add_session_run_call = [use_global_pools, &def, &options, &sessions,
-                               &sessions_mu,
-                               &num_done](thread::ThreadPool* tp, Node* node,
-                                          int inter_op_pool) {
+  auto add_session_run_call =
+      [use_global_pools, &def, &options, &sessions, &sessions_mu, &num_done](
+          thread::ThreadPool* tp, Node* node, int inter_op_pool) {
     auto fn = [use_global_pools, &def, &options, &sessions, &sessions_mu,
                inter_op_pool, node, &num_done]() {
       RunOptions run_options;


@@ -167,7 +167,7 @@ static void TestHWAccelerator(bool enableHWTrace) {
   Node* y = test::graph::Matmul(&graph, a, x, false, false);
   y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
 #ifdef TENSORFLOW_USE_SYCL
   y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0");
-#endif // TENSORFLOW_USE_SYCL
+#endif  // TENSORFLOW_USE_SYCL
   Node* y_neg = test::graph::Unary(&graph, "Neg", y);


@@ -154,7 +154,8 @@ TEST(GPUBFCAllocatorTest, ExerciseCoalescing) {
     a.DeallocateRaw(t3);
     a.DeallocateRaw(t4);
   }
-  CheckStats(&a, 4097, 0, 1024 * sizeof(float) + 1048576 * sizeof(int64) +
+  CheckStats(&a, 4097, 0,
+             1024 * sizeof(float) + 1048576 * sizeof(int64) +
                  2048 * sizeof(double) + 10485760 * sizeof(float),
              10485760 * sizeof(float));


@@ -763,8 +763,9 @@ int64 MinSystemMemory(int64 available_memory) {
   min_system_memory *= 2;
 #endif
 #if defined(NVIDIA_TEGRA)
-  // 1GB system mem for NVIDIA Tegra devices since they use the same mem for RAM and Video RAM
-  min_system_memory = 1<<30;
+  // 1GB system mem for NVIDIA Tegra devices since they use the same mem for RAM
+  // and Video RAM
+  min_system_memory = 1 << 30;
 #endif
   return min_system_memory;
 }


@@ -108,7 +108,8 @@ TEST_F(GpuStreamUtilTest, StreamOverrides) {
   ops::_Recv(root.WithOpName("input"), DT_FLOAT, "input", "/cpu:0", 0,
              "/device:GPU:0");
   Output n = ops::MatMul(root, {}, {});
-  ops::_Send(root.WithOpName("output"), n, "output", "/device:GPU:0", 0, "/cpu:0");
+  ops::_Send(root.WithOpName("output"), n, "output", "/device:GPU:0", 0,
+             "/cpu:0");
   Graph g(OpRegistry::Global());
   TF_ASSERT_OK(root.ToGraph(&g));


@@ -88,8 +88,8 @@ ProcessState::~ProcessState() {
 }
 string ProcessState::MemDesc::DebugString() {
-  return strings::StrCat((loc == CPU ? "CPU " : "GPU "), dev_index, ", dma: ",
-                         gpu_registered, ", nic: ", nic_registered);
+  return strings::StrCat((loc == CPU ? "CPU " : "GPU "), dev_index,
+                         ", dma: ", gpu_registered, ", nic: ", nic_registered);
 }
 ProcessState::MemDesc ProcessState::PtrType(const void* ptr) {


@@ -139,9 +139,7 @@ class GraphExecutionState {
   // The graph returned by BuildGraph may contain only the pruned
   // graph, whereas some clients may want access to the full graph.
-  const Graph* full_graph() {
-    return graph_;
-  }
+  const Graph* full_graph() { return graph_; }
   // Returns the node with the given name, or null if it does not exist.
   const Node* get_node_by_name(const string& name) const {


@@ -47,7 +47,7 @@ struct EndpointEq {
 static Status ProcessMemoryTypes(
     const DeviceType& device_type, const Graph* g,
     const std::function<Status(const Edge*, MemoryType, MemoryType)>& fn) {
-  if (device_type != DEVICE_GPU && device_type != DEVICE_SYCL ) {
+  if (device_type != DEVICE_GPU && device_type != DEVICE_SYCL) {
     // On non-GPU and non-SYCL devices, HOST_MEMORY and DEVICE_MEMORY are always
     // compatible.
     return Status::OK();


@@ -619,8 +619,8 @@ TEST_F(PlacerTest, TestReferenceConnectionIgnoreInfeasible) {
   Node* input = ops::SourceOp(
       "TestDevice",
       b.opts().WithName("in").WithDevice("/job:a/task:0/device:fakegpu:0"));
-  Node* var = ops::SourceOp("TestVariable",
-                            b.opts().WithName("var_0").WithDevice(
-                                "/job:a/task:0/device:fakegpu:0"));
+  Node* var =
+      ops::SourceOp("TestVariable", b.opts().WithName("var_0").WithDevice(
+                                        "/job:a/task:0/device:fakegpu:0"));
   // This op is specified on CPU, but in practice will be ignored,


@@ -60,8 +60,8 @@ const string RegisteredFactoriesErrorMessageLocked() {
                          str_util::Join(factory_types, ", "), "}.");
 }
 string SessionOptionsToString(const SessionOptions& options) {
-  return strings::StrCat("target: \"", options.target, "\" config: ",
-                         ProtoShortDebugString(options.config));
+  return strings::StrCat("target: \"", options.target,
+                         "\" config: ", ProtoShortDebugString(options.config));
 }
 }  // namespace


@@ -226,12 +226,14 @@ void StepStatsCollector::BuildCostModel(
     if (node) {
       for (int i = 0; i < stats.output_size(); ++i) {
         const auto& output = stats.output(i);
-        cm->RecordMaxMemorySize(node, i, Bytes(output.tensor_description()
-                                                   .allocation_description()
-                                                   .allocated_bytes()),
+        cm->RecordMaxMemorySize(node, i,
+                                Bytes(output.tensor_description()
+                                          .allocation_description()
+                                          .allocated_bytes()),
                                 stats.output(i).tensor_description().shape(),
                                 node->output_types()[i]);
-        cm->RecordAllocationId(node, i, output.tensor_description()
-                                            .allocation_description()
-                                            .allocation_id());
+        cm->RecordAllocationId(node, i,
+                               output.tensor_description()
+                                   .allocation_description()
+                                   .allocation_id());
       }
@@ -239,8 +241,7 @@ void StepStatsCollector::BuildCostModel(
       // Use hardware stats to record the execution time if they're available,
       // otherwise use the regular (less accurate) stats
       string node_name = dev_stats.regular_stats->node_stats(i).node_name();
-      if (dev_stats.hardware_stats &&
-          name_to_hw_node_stats.find(node_name) !=
+      if (dev_stats.hardware_stats && name_to_hw_node_stats.find(node_name) !=
                                           name_to_hw_node_stats.end()) {
         const NodeExecStats& hw_stats = name_to_hw_node_stats[node_name];
         cm->RecordMaxExecutionTime(


@@ -80,7 +80,7 @@ void SYCLAllocator::ClearStats() override {
 size_t SYCLAllocator::RequestedSize(void* ptr) {
   mutex_lock lock(mu_);
-  if(!sycl_device_) {
+  if (!sycl_device_) {
     return 0;
   }
   const auto& buffer = sycl_device_->get_sycl_buffer(ptr);


@@ -20,10 +20,10 @@ limitations under the License.
 #ifndef TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
 #define TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 namespace tensorflow {
@@ -56,14 +56,13 @@ class SYCLAllocator : public Allocator {
   // Clear the SYCL device used by the Allocator
   void ClearSYCLDevice() {
     mutex_lock lock(mu_);
-    if(sycl_device_) {
+    if (sycl_device_) {
       delete sycl_device_;
       sycl_device_ = nullptr;
     }
   }
  private:
   mutable mutex mu_;
   Eigen::SyclDevice* sycl_device_ GUARDED_BY(mu_);  // owned
   AllocatorStats stats_ GUARDED_BY(mu_);


@@ -187,9 +187,9 @@ class GSYCLInterface {
       type = "Unknown";
     }
-    return strings::StrCat("id: ", device_id, ", type: ", type, ", name: ",
-                           name.c_str(), ", vendor: ", vendor.c_str(),
-                           ", profile: ", profile.c_str());
+    return strings::StrCat(
+        "id: ", device_id, ", type: ", type, ", name: ", name.c_str(),
+        ", vendor: ", vendor.c_str(), ", profile: ", profile.c_str());
   }
 };


@@ -26,7 +26,6 @@ class SYCLDeviceFactory : public DeviceFactory {
  public:
   Status CreateDevices(const SessionOptions &options, const string &name_prefix,
                        std::vector<Device *> *devices) override {
     auto syclInterface = GSYCLInterface::instance();
     size_t n = 1;
@@ -37,13 +36,11 @@ class SYCLDeviceFactory : public DeviceFactory {
     for (int i = 0; i < n; i++) {
       string name = strings::StrCat(name_prefix, "/device:SYCL:", i);
-      devices->push_back(
-          new SYCLDevice(options, name, Bytes(256 << 20), DeviceLocality()
-          , syclInterface->GetShortDeviceDescription(i)
-          , syclInterface->GetSYCLAllocator(i)
-          , syclInterface->GetCPUAllocator(i)
-          , syclInterface->GetSYCLContext(i))
-          );
+      devices->push_back(new SYCLDevice(
+          options, name, Bytes(256 << 20), DeviceLocality(),
+          syclInterface->GetShortDeviceDescription(i),
+          syclInterface->GetSYCLAllocator(i), syclInterface->GetCPUAllocator(i),
+          syclInterface->GetSYCLContext(i)));
     }
     return Status::OK();
@@ -51,6 +48,6 @@ class SYCLDeviceFactory : public DeviceFactory {
 };
 REGISTER_LOCAL_DEVICE_FACTORY("SYCL", SYCLDeviceFactory, 200);
-}
+}  // namespace tensorflow
 #endif  // TENSORFLOW_USE_SYCL


@ -20,8 +20,8 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_ #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_
#include "tensorflow/core/common_runtime/device.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
// For DMA helper // For DMA helper
#include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"


@@ -24,25 +24,25 @@ limitations under the License.
 namespace tensorflow {
 DebugGateway::DebugGateway(DirectSession* session) : session_(session) {
-  session_->node_outputs_callback_ = [this](
-      const string& node_name, const int output_slot, const Tensor* tensor,
-      const bool is_ref, OpKernelContext* ctx) {
+  session_->node_outputs_callback_ =
+      [this](const string& node_name, const int output_slot,
+             const Tensor* tensor, const bool is_ref, OpKernelContext* ctx) {
         if (comp_cb_ != nullptr && output_slot <= 0) {
           // The node completion callback is invoked once for a node regardless
           // of whether the node has zero, one or more outputs.
           // The output_slot can be negative (-1, or kControlSlot) if
-          // node_outputs_callback_ is invoked for a node with no output. If that
-          // is the case, notify the callback that the node in question has no
-          // output.
+          // node_outputs_callback_ is invoked for a node with no output. If
+          // that is the case, notify the callback that the node in question has
+          // no output.
           comp_cb_(node_name, output_slot == 0);
         }
         // Copy tensor values (e.g., from GPU to host) only if the
         // value callback is not nullptr.
         if (val_cb_ != nullptr && output_slot >= 0) {
-          CopyTensor(
-              node_name, output_slot, tensor, ctx,
-              [this, node_name, output_slot, is_ref](const Tensor* copied_tensor) {
+          CopyTensor(node_name, output_slot, tensor, ctx,
+                     [this, node_name, output_slot,
+                      is_ref](const Tensor* copied_tensor) {
                        val_cb_(node_name, output_slot, *copied_tensor, is_ref);
                      });
         }
@@ -86,7 +86,8 @@ void DebugGateway::CopyTensor(const string& node_name, const int output_slot,
   // Determine if the tensor is on device (GPU) or host (CPU).
   // The second part of the check is necessary because even an OpKernel on
   // may have output tensors allocated on CPU.
-  if ((device->name().find("GPU:") != string::npos || device->name().find("SYCL:") != string::npos) &&
+  if ((device->name().find("GPU:") != string::npos ||
+       device->name().find("SYCL:") != string::npos) &&
       !ctx->output_alloc_attr(output_slot).on_host()) {
     // GPU tensors: Copy it to host (CPU).
     DeviceContext* device_ctxt = ctx->op_device_context();
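A minimal usage sketch of the gateway above, assuming a live DirectSession* and the tensorflow/core/debug headers (LogAllNodeValues is a hypothetical helper, not part of this commit):

    #include "tensorflow/core/debug/debug_gateway.h"
    #include "tensorflow/core/platform/logging.h"

    namespace tensorflow {

    void LogAllNodeValues(DirectSession* session) {
      DebugGateway debug_gateway(session);  // hooks node_outputs_callback_
      debug_gateway.SetNodeValueCallback(
          [](const string& node_name, const int output_slot,
             const Tensor& tensor_value, const bool is_ref) {
            // Fires after CopyTensor() has brought GPU/SYCL outputs to host.
            LOG(INFO) << node_name << ":" << output_slot << " = "
                      << tensor_value.DebugString()
                      << (is_ref ? " (ref)" : "");
          });
      // ... Run() the session while the gateway is alive; the callback
      // fires once per non-control output slot ...
    }

    }  // namespace tensorflow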


@@ -390,8 +390,8 @@ TEST_F(SessionDebugMinusAXTest,
   debug_gateway.SetNodeValueCallback(
       [this, &mu, &val_callback_count, &a_debug_identity_node_name,
        &x_debug_identity_node_name, &y_debug_identity_node_name,
-       &debug_identity_tensor_vals, &callbacks_done, &kConcurrentRuns](
-          const string& node_name, const int output_slot,
-          const Tensor& tensor_value, const bool is_ref) {
+       &debug_identity_tensor_vals, &callbacks_done,
+       &kConcurrentRuns](const string& node_name, const int output_slot,
+                         const Tensor& tensor_value, const bool is_ref) {
         mutex_lock l(mu);
@@ -560,8 +560,8 @@ TEST_F(SessionDebugOutputSlotWithoutOutgoingEdgeTest,
   Notification callbacks_done;
   std::vector<Tensor> debug_identity_tensor_vals;
-  debug_gateway.SetNodeValueCallback([this, &mu, &callbacks_done,
-                                      &debug_identity_node_name,
-                                      &debug_identity_tensor_vals](
-      const string& node_name, const int output_slot,
-      const Tensor& tensor_value, const bool is_ref) {
+  debug_gateway.SetNodeValueCallback(
+      [this, &mu, &callbacks_done, &debug_identity_node_name,
+       &debug_identity_tensor_vals](
+          const string& node_name, const int output_slot,
+          const Tensor& tensor_value, const bool is_ref) {


@@ -30,7 +30,7 @@ namespace test {
 ::grpc::Status TestEventListenerImpl::SendEvents(
     ::grpc::ServerContext* context,
-    ::grpc::ServerReaderWriter< ::tensorflow::EventReply, ::tensorflow::Event>*
+    ::grpc::ServerReaderWriter<::tensorflow::EventReply, ::tensorflow::Event>*
         stream) {
   Event event;


@@ -57,7 +57,8 @@ class DebugIOUtilsTest : public ::testing::Test {
 TEST_F(DebugIOUtilsTest, ConstructDebugNodeKey) {
   DebugNodeKey debug_node_key("/job:worker/replica:1/task:0/device:GPU:2",
                               "hidden_1/MatMul", 0, "DebugIdentity");
-  EXPECT_EQ("/job:worker/replica:1/task:0/device:GPU:2", debug_node_key.device_name);
+  EXPECT_EQ("/job:worker/replica:1/task:0/device:GPU:2",
+            debug_node_key.device_name);
   EXPECT_EQ("hidden_1/MatMul", debug_node_key.node_name);
   EXPECT_EQ(0, debug_node_key.output_slot);
   EXPECT_EQ("DebugIdentity", debug_node_key.debug_op);


@@ -185,18 +185,17 @@ class GrpcMasterService : public AsyncServiceInterface {
     MutableRunStepResponseWrapper* wrapped_response =
         new NonOwnedProtoRunStepResponse(&call->response);
     call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
-    master_impl_->RunStep(call_opts, wrapped_request, wrapped_response,
-                          [call, call_opts, wrapped_request, wrapped_response,
-                           trace](const Status& status) {
+    master_impl_->RunStep(
+        call_opts, wrapped_request, wrapped_response,
+        [call, call_opts, wrapped_request, wrapped_response,
+         trace](const Status& status) {
          call->ClearCancelCallback();
          delete call_opts;
          delete wrapped_request;
          delete trace;
-          if (call->request.store_errors_in_response_body() &&
-              !status.ok()) {
+          if (call->request.store_errors_in_response_body() && !status.ok()) {
            call->response.set_status_code(status.code());
-            call->response.set_status_error_message(
-                status.error_message());
+            call->response.set_status_error_message(status.error_message());
            call->SendResponse(ToGrpcStatus(Status::OK()));
          } else {
            call->SendResponse(ToGrpcStatus(status));


@@ -89,8 +89,8 @@ class MasterService final {
   ::grpc::Status ExtendSession(::grpc::ClientContext* context,
                                const ExtendSessionRequest& request,
                                ExtendSessionResponse* response) override;
-  ::grpc::Status PartialRunSetup(
-      ::grpc::ClientContext* context, const PartialRunSetupRequest& request,
-      PartialRunSetupResponse* response) override;
+  ::grpc::Status PartialRunSetup(::grpc::ClientContext* context,
+                                 const PartialRunSetupRequest& request,
+                                 PartialRunSetupResponse* response) override;
   ::grpc::Status RunStep(::grpc::ClientContext* context,
                          const RunStepRequest& request,


@@ -69,8 +69,7 @@ class GrpcRemoteMaster : public MasterInterface {
     ::grpc::ClientContext ctx;
     auto trace = TraceRpc("RunStep/Client", &ctx);
     return Call(&ctx, call_options, &request->ToProto(),
-                get_proto_from_wrapper(response),
-                &MasterServiceStub::RunStep);
+                get_proto_from_wrapper(response), &MasterServiceStub::RunStep);
   }
   Status CloseSession(CallOptions* call_options,
@@ -114,8 +113,9 @@ class GrpcRemoteMaster : public MasterInterface {
   template <typename Request, typename Response>
   Status Call(::grpc::ClientContext* ctx, CallOptions* call_options,
               const Request* request, Response* response,
-              ::grpc::Status (MasterServiceStub::*pfunc)(
-                  ::grpc::ClientContext*, const Request&, Response*)) {
+              ::grpc::Status (MasterServiceStub::*pfunc)(::grpc::ClientContext*,
+                                                         const Request&,
+                                                         Response*)) {
     ctx->set_fail_fast(false);
     SetDeadline(ctx, call_options->GetTimeout());
     return FromGrpcStatus((stub_.get()->*pfunc)(ctx, *request, response));
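Call() above funnels every RPC wrapper through one deadline/fail-fast code path by taking a pointer to a member function of the stub. A self-contained sketch of that C++ pattern (Stub, Request, and Response here are placeholders, not TensorFlow types):

    #include <iostream>
    #include <string>

    struct Stub {
      int Run(const std::string& request, std::string* response) {
        *response = "ran: " + request;
        return 0;
      }
    };

    // Shared dispatch: common setup could live here, then the call is
    // forwarded through the member-function pointer, as in
    // GrpcRemoteMaster::Call.
    template <typename Request, typename Response>
    int Call(Stub* stub, const Request& request, Response* response,
             int (Stub::*pfunc)(const Request&, Response*)) {
      return (stub->*pfunc)(request, response);
    }

    int main() {
      Stub stub;
      std::string out;
      Call(&stub, std::string("step"), &out, &Stub::Run);
      std::cout << out << "\n";  // prints "ran: step"
      return 0;
    }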


@@ -21,11 +21,8 @@ namespace tensorflow {
 namespace test {
 // ErrorOp::Compute returns an error.
-REGISTER_OP("Error")
-    .Input("in: T")
-    .Output("out: T")
-    .Attr("T: type")
-    .Attr("message: string");
+REGISTER_OP("Error").Input("in: T").Output("out: T").Attr("T: type").Attr(
+    "message: string");
 class ErrorOp : public OpKernel {
  public:
   explicit ErrorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -66,11 +63,8 @@ REGISTER_KERNEL_BUILDER(Name("InvalidRefType").Device(DEVICE_CPU),
 // DelayOp::AsyncCompute sleeps for "micros"-econd and then returns
 // its input.
-REGISTER_OP("Delay")
-    .Input("in: T")
-    .Output("out: T")
-    .Attr("T: type")
-    .Attr("micros: int");
+REGISTER_OP("Delay").Input("in: T").Output("out: T").Attr("T: type").Attr(
+    "micros: int");
 class DelayOp : public AsyncOpKernel {
  public:
   explicit DelayOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {


@@ -17,9 +17,9 @@ limitations under the License.
 #include <queue>
-#include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/util/util.h"
 namespace tensorflow {


@@ -16,15 +16,15 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SCHEDULER_H_
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SCHEDULER_H_
-#include <functional>
 #include <deque>
+#include <functional>
 #include <map>
 #include <unordered_map>
 #include <vector>
-#include "tensorflow/core/graph/costmodel.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/graph/costmodel.h"
 namespace tensorflow {


@@ -97,9 +97,8 @@ void WorkerCacheLogger::RecordDataTransfer(int64 step_id, int64 start_usecs,
                                            const string& tensor_name,
                                            const string& src_device,
                                            const string& dst_device,
-                                           int64 bytes,
-                                           const string& details,
-                                           const string& transfer_method_name){
+                                           int64 bytes, const string& details,
+                                           const string& transfer_method_name) {
   NodeExecStats* ns = new NodeExecStats;
   ns->set_node_name(transfer_method_name);
   if (details.empty()) {


@@ -158,8 +158,8 @@ void CostModel::SetNumOutputs(const Node* node, int num_outputs) {
   Ensure(id, 0);
   auto perslot = &slot_bytes_[id];
   if (!perslot->empty()) {
-    CHECK_EQ(num_outputs, perslot->size()) << "Cannot resize slot_bytes, node="
-                                           << node->name();
+    CHECK_EQ(num_outputs, perslot->size())
+        << "Cannot resize slot_bytes, node=" << node->name();
   }
   Ensure(id, num_outputs);
 }


@@ -198,7 +198,7 @@ class CostModel {
   // Cumulative execution time.
   std::vector<Microseconds> time_;
   // Cumulative Bytes output on each channel.
-  std::vector<gtl::InlinedVector<Bytes, 2> > slot_bytes_;
+  std::vector<gtl::InlinedVector<Bytes, 2>> slot_bytes_;
   // Maximum execution time
   std::vector<Microseconds> max_exec_time_;
@@ -217,7 +217,7 @@ class CostModel {
   };
   std::vector<MemUsage> max_mem_usage_;
-  std::vector<gtl::InlinedVector<int64, 2> > output_port_alloc_ids_;
+  std::vector<gtl::InlinedVector<int64, 2>> output_port_alloc_ids_;
   std::set<int64> persistent_alloc_ids_;
   std::map<string, std::set<int64>> persistent_alloc_ids_by_devices_;


@@ -26,7 +26,6 @@ namespace tensorflow {
 namespace {
 TEST(GraphDefBuilderTest, Version) {
   // Verify that our assertions will be nontrivial
   ASSERT_LT(0, TF_GRAPH_DEF_VERSION);


@@ -21,31 +21,31 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 namespace tensorflow {
 // Since our ops are going to produce and also consume N addition tensors
 // (Mkl) for N Tensorflow tensors, we can have following different
 // orderings among these 2N tensors.
 //
 // E.g., for Tensorflow tensors A, B, and C, our ops will produce and
 // consume A_m, B_m, and C_m additionally.
 //
 // INTERLEAVED: in this case 2N tensors are interleaved. So for above
 // example, the ordering looks like: A, A_m, B, B_m, C, C_m.
 //
 // CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed
 // by N Mkl tensors. So for above example, the ordering looks
 // like: A, B, C, A_m, B_m, C_m
 //
 // Following APIs map index of original Tensorflow tensors to their
 // appropriate position based on selected ordering. For contiguous ordering,
 // we need to know the total number of tensors (parameter total).
 //
 typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering;
 // NOTE: Currently, we use contiguous ordering. If you change this, then you
 // would need to change Mkl op definitions in nn_ops.cc.
 static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
 // Get index of MetaData tensor from index 'n' of Data tensor.
 inline int DataIndexToMetaDataIndex(int n, int total_tensors) {
   if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
     // For interleaved ordering, Mkl tensor follows immediately after
     // Tensorflow tensor.
@@ -55,57 +55,56 @@ namespace tensorflow {
     // For contiguous ordering, Mkl tensor is n+total_tensors / 2 away.
     return n + total_tensors / 2;
   }
 }
 int inline GetTensorDataIndex(int n, int total_tensors) {
   if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
     return 2 * n;  // index corresponding to nth input/output tensor
   } else {
     CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
     return n;
   }
 }
 int inline GetTensorMetaDataIndex(int n, int total_tensors) {
   // Get index for TensorData first and then use mapping function
   // to get TensorMetaData index from TensorData index.
   int tidx = GetTensorDataIndex(n, total_tensors);
   return DataIndexToMetaDataIndex(tidx, total_tensors);
 }
 namespace mkl_op_registry {
 static const char* kMklOpLabel = "MklOp";
 static const char* kMklOpLabelPattern = "label='MklOp'";
 // Prefix that we add to Tensorflow op name to construct Mkl op name.
 static const char* const kMklOpPrefix = "_Mkl";
 // Get the name of Mkl op from original TensorFlow op
 // We prefix 'Mkl' to the original op to get Mkl op.
 inline string GetMklOpName(const string& name) {
   return string(kMklOpPrefix) + name;
 }
 // Check whether opname with type T is registered as MKL-compliant.
 //
 // @input: name of the op
 // @input: T datatype to be used for checking op
 // @return: true if opname is registered as Mkl op; false otherwise
 static inline bool IsMklOp(const std::string& op_name, DataType T) {
   string kernel = KernelsRegisteredForOp(op_name);
   bool result =
       kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT);
   return result;
 }
 // Check whether opname with type T is registered as MKL-compliant and
 // is element-wise.
 //
 // @input: name of the op
 // @input: T datatype to be used for checking op
 // @return: true if opname is registered as element-wise Mkl op;
 // false otherwise
-static inline bool IsMklElementWiseOp(const std::string& op_name,
-                                      DataType T) {
+static inline bool IsMklElementWiseOp(const std::string& op_name, DataType T) {
   if (!IsMklOp(op_name, T)) {
     return false;
   }
@@ -116,7 +115,7 @@ namespace mkl_op_registry {
       0 == op_name.compare(GetMklOpName("SquaredDifference")));
   return result;
 }
 }  // namespace mkl_op_registry
 }  // namespace tensorflow
 #endif  // INTEL_MKL
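A worked example of the index mapping documented above: the helpers below re-implement the two mapping functions standalone (an illustrative sketch, not the header itself) and print the slot assignments for N = 3 TensorFlow tensors under the contiguous ordering that kTensorOrdering selects.

    #include <iostream>

    namespace {
    enum Ordering { kInterleaved, kContiguous };
    constexpr Ordering kOrdering = kContiguous;  // mirrors kTensorOrdering

    int GetTensorDataIndex(int n, int /*total_tensors*/) {
      return kOrdering == kInterleaved ? 2 * n : n;
    }
    int DataIndexToMetaDataIndex(int n, int total_tensors) {
      // Interleaved: the Mkl metadata tensor follows its data tensor.
      // Contiguous: it sits total_tensors / 2 slots later.
      return kOrdering == kInterleaved ? n + 1 : n + total_tensors / 2;
    }
    }  // namespace

    int main() {
      // Tensors A, B, C plus A_m, B_m, C_m give 2N = 6 slots; contiguous
      // ordering is A, B, C, A_m, B_m, C_m, so data slots are 0..2 and
      // metadata slots are 3..5.
      for (int n = 0; n < 3; ++n) {
        int tidx = GetTensorDataIndex(n, 6);
        std::cout << "tensor " << n << ": data slot " << tidx
                  << ", metadata slot " << DataIndexToMetaDataIndex(tidx, 6)
                  << "\n";
      }
      return 0;
    }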


@@ -37,8 +37,8 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/util/tensor_format.h"
-#include "tensorflow/core/graph/mkl_layout_pass.h"
 #include "tensorflow/core/graph/mkl_graph_util.h"
+#include "tensorflow/core/graph/mkl_layout_pass.h"
 namespace tensorflow {
@@ -297,10 +297,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     // End - element-wise ops. See note above.
     // NOTE: names are alphabetically sorted.
-    rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), CopyAttrsAddN,
-                      AddNRewrite, nullptr});
-    rinfo_.push_back({csinfo_.add,
-                      mkl_op_registry::GetMklOpName(csinfo_.add),
+    rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
+                      CopyAttrsAddN, AddNRewrite, nullptr});
+    rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
                       CopyAttrsDataType, AlwaysRewrite, nullptr});
     rinfo_.push_back({csinfo_.avg_pool,
                       mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
@@ -337,14 +336,14 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     rinfo_.push_back({csinfo_.fused_batch_norm,
                       mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
                       CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
-    rinfo_.push_back({csinfo_.fused_batch_norm_grad,
-                      mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
-                      CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
+    rinfo_.push_back(
+        {csinfo_.fused_batch_norm_grad,
+         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
+         CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
     rinfo_.push_back({csinfo_.identity,
                       mkl_op_registry::GetMklOpName(csinfo_.identity),
                       CopyAttrsIdentity, AlwaysRewrite, nullptr});
-    rinfo_.push_back({csinfo_.lrn,
-                      mkl_op_registry::GetMklOpName(csinfo_.lrn),
+    rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
                       CopyAttrsLRN, AlwaysRewrite, nullptr});
     rinfo_.push_back({csinfo_.lrn_grad,
                       mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
@@ -358,11 +357,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     rinfo_.push_back({csinfo_.maximum,
                       mkl_op_registry::GetMklOpName(csinfo_.maximum),
                       CopyAttrsDataType, AlwaysRewrite, nullptr});
-    rinfo_.push_back({csinfo_.mul,
-                      mkl_op_registry::GetMklOpName(csinfo_.mul),
+    rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul),
                       CopyAttrsDataType, AlwaysRewrite, nullptr});
-    rinfo_.push_back({csinfo_.relu,
-                      mkl_op_registry::GetMklOpName(csinfo_.relu),
+    rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
                       CopyAttrsDataType, AlwaysRewrite, nullptr});
     rinfo_.push_back({csinfo_.relu_grad,
                       mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
@@ -373,8 +370,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     rinfo_.push_back({csinfo_.squared_difference,
                       mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
                       CopyAttrsDataType, AlwaysRewrite, nullptr});
-    rinfo_.push_back({csinfo_.sub,
-                      mkl_op_registry::GetMklOpName(csinfo_.sub),
+    rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
                       CopyAttrsDataType, AlwaysRewrite, nullptr});
     // Add info about which ops to add workspace edge to and the slots.
@@ -388,8 +384,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul,
                                    IsBiasAddGradInMatMulContext};
-    biasaddgrad_conv2dwithbias_context_ = {csinfo_.bias_add_grad,
-                                           csinfo_.mkl_conv2d_with_bias,
-                                           IsBiasAddGradInConv2DWithBiasContext};
+    biasaddgrad_conv2dwithbias_context_ = {
+        csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias,
+        IsBiasAddGradInConv2DWithBiasContext};
     cinfo_.push_back(&biasaddgrad_matmul_context_);
@@ -615,8 +611,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     std::vector<int32> ksize, strides;
     CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true);
     CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true);
-    CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(),
-             true);
+    CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true);
     CHECK_EQ(FormatFromString(data_format_str, &data_format), true);
     // Condition that specifies non-batch-wise and non-depth-wise pooling.
@@ -785,8 +780,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     for (const Edge* fe : first_inp_of_filter->out_edges()) {
       if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
           fe->dst_input() == 0) {
-        VLOG(1) << "MklLayoutRewritePass: found "
-                << fe->dst()->DebugString()
+        VLOG(1) << "MklLayoutRewritePass: found " << fe->dst()->DebugString()
                 << " as the forward node for matching context, backward"
                 << " node is: " << n->DebugString();
         *fwd_node = fe->dst();
@@ -803,13 +797,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   //
   // @return - true (if BiasAddGrad is associated with MatMul);
   // false otherwise.
-  static bool IsBiasAddGradInMatMulContext(const Node* n,
-                                           const Node** fwd_node,
-                                           void* ci) {
+  static bool IsBiasAddGradInMatMulContext(const Node* n, const Node** fwd_node,
+                                           void* ci) {
     return (!IsBiasAddGradInConv2DWithBiasContext(n, fwd_node, ci));
   }
   // Rewrite rule that uses context-information for matching,
   // used in scenario 2.
   //
@@ -880,8 +872,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   // @output output_nodes - the list of new nodes creating Mkl tensors
   //
   // @return None
-  void GetNodesProducingMklTensorList(std::unique_ptr<Graph>* g,
-      Node* orig_node, const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
+  void GetNodesProducingMklTensorList(
+      std::unique_ptr<Graph>* g, Node* orig_node,
+      const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
       int* input_idx, int list_length,
      std::vector<NodeBuilder::NodeOut>* output_nodes);
@@ -900,7 +893,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   // will feed the tensor
   // @return None
   void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node,
-                                 Node* n, int n_output_slot, Node** mkl_node, int* mkl_node_output_slot);
+                                 Node* n, int n_output_slot, Node** mkl_node,
+                                 int* mkl_node_output_slot);
   // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
   // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
@@ -1060,8 +1054,8 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
   // the same frame.
   if (orig_node->num_inputs() > 0) {
     Node* orig_input0 = nullptr;
-    TF_CHECK_OK(orig_node->input_node(0,
-                                      const_cast<const Node**>(&orig_input0)));
+    TF_CHECK_OK(
+        orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
     CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out));
   }
@@ -1069,11 +1063,9 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
 }
 void MklLayoutRewritePass::GetNodesProducingMklTensorList(
-    std::unique_ptr<Graph>* g,
-    Node* orig_node,
-    const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
-    int* input_idx, int list_length,
-    std::vector<NodeBuilder::NodeOut>* output_nodes) {
+    std::unique_ptr<Graph>* g, Node* orig_node,
+    const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
+    int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
   CHECK_LT(*input_idx, inputs.size());
   CHECK_GT(list_length, 0);
   CHECK_NOTNULL(output_nodes);
@@ -1090,8 +1082,8 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
     int mkl_node_output_slot = 0;
     GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
                               &mkl_node_output_slot);
-    output_nodes->push_back(NodeBuilder::NodeOut(mkl_node,
-                                                 mkl_node_output_slot));
+    output_nodes->push_back(
+        NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
     (*input_idx)++;
     list_length--;
   }
@@ -1101,9 +1093,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
 // node that we are constructing. An input node could be (1) 'n'
 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
 // if 'n' is not an Mkl layer.
-void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
-    Node* orig_node, Node* n,
-    int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) {
+void MklLayoutRewritePass::GetNodeProducingMklTensor(
+    std::unique_ptr<Graph>* g, Node* orig_node, Node* n, int n_output_slot,
+    Node** mkl_node, int* mkl_node_output_slot) {
   CHECK_NOTNULL(n);
   CHECK_NOTNULL(mkl_node);
   CHECK_NOTNULL(mkl_node_output_slot);
@@ -1234,8 +1226,8 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
     if (ArgIsList(arg)) {
       std::vector<NodeBuilder::NodeOut> new_node_inputs;
      int N = GetTensorListLength(arg, old_node);
-      GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx,
-                                     N, &new_node_inputs);
+      GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N,
+                                     &new_node_inputs);
       nb->Input(new_node_inputs);
       nn_slot_idx++;
     } else {
@@ -1355,8 +1347,8 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
   // the same frame.
   if (orig_node->num_inputs() > 0) {
     Node* orig_input0 = nullptr;
-    TF_CHECK_OK(orig_node->input_node(0,
-                                      const_cast<const Node**>(&orig_input0)));
+    TF_CHECK_OK(
+        orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
     CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out));
   }
@@ -1374,7 +1366,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   for (auto ws : wsinfo_) {
     if (orig_node->type_string() == ws.fwd_op &&
-        mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
+        mkl_op_registry::IsMklOp(
+            mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
       // If this op is a fwd op, then we need to check if there is an
       // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
      // an edge, then we just add an attribute on this node for setting
@@ -1400,7 +1393,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
        nb->Attr("workspace_enabled", false);
      }
    } else if (orig_node->type_string() == ws.bwd_op &&
-               mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(orig_node->type_string()),
-                                        T)) {
+               mkl_op_registry::IsMklOp(
+                   mkl_op_registry::GetMklOpName(orig_node->type_string()),
+                   T)) {
      // If this op is a bwd op, then we need to add workspace edge and
      // it's Mkl tensor edge between its corresponding fwd op and this
@@ -1416,7 +1410,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
        if (e->src_output() == ws.fwd_slot &&
            // We would have rewritten the forward op, so we need to use
            // GetMklOpName call to get its Mkl name.
-            e->src()->type_string() == mkl_op_registry::GetMklOpName(ws.fwd_op) &&
+            e->src()->type_string() ==
+                mkl_op_registry::GetMklOpName(ws.fwd_op) &&
            e->dst_input() == ws.bwd_slot) {
          nb->Attr("workspace_enabled", true);
          CHECK_NOTNULL(ws_tensors);
@@ -1869,8 +1864,8 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
     if (e->IsControlEdge()) {
       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
     } else {
-      CHECK_NOTNULL((*g)->AddEdge(new_node, e->src_output(), e->dst(),
-                                  e->dst_input()));
+      CHECK_NOTNULL(
+          (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input()));
     }
   }
@@ -2012,8 +2007,9 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
     if (e->IsControlEdge()) {
       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
     } else {
-      CHECK_NOTNULL((*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(),
-                                                               e->src()->num_outputs()),
-                                  e->dst(), e->dst_input()));
+      CHECK_NOTNULL((*g)->AddEdge(
+          new_node,
+          GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
+          e->dst(), e->dst_input()));
     }
   }
@@ -2070,7 +2066,8 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
   // BiasAddGrad is not an Mkl layer, so we make an exception for it.
   if (n->type_string() != csinfo_.bias_add_grad) {
-    if (!mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()), T)) {
+    if (!mkl_op_registry::IsMklOp(
+            mkl_op_registry::GetMklOpName(n->type_string()), T)) {
       return nullptr;
     }
   }
@@ -2186,8 +2183,7 @@ bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
   return MklLayoutRewritePass().RunPass(g);
 }
-Status MklLayoutRewritePass::Run(
-    const GraphOptimizationPassOptions& options) {
+Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
   if (options.graph == nullptr && options.partition_graphs == nullptr) {
     return Status::OK();
   }
@@ -2474,29 +2470,28 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     rinfo_.push_back({csinfo_.conv2d,
                       mkl_op_registry::GetMklOpName(csinfo_.conv2d),
                       CopyAttrsConv2D, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.conv2d_with_bias,
-                      csinfo_.mkl_conv2d_with_bias,
+    rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias,
                       CopyAttrsConv2D, AlwaysRewrite});
     rinfo_.push_back({csinfo_.conv2d_grad_filter,
                       mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
                       CopyAttrsConv2D, AlwaysRewrite});
     rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
-                      csinfo_.mkl_conv2d_grad_filter_with_bias,
-                      CopyAttrsConv2D, AlwaysRewrite});
+                      csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv2D,
+                      AlwaysRewrite});
     rinfo_.push_back({csinfo_.conv2d_grad_input,
                       mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
                       CopyAttrsConv2D, AlwaysRewrite});
     rinfo_.push_back({csinfo_.fused_batch_norm,
                       mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
                       CopyAttrsFusedBatchNorm, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.fused_batch_norm_grad,
-                      mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
-                      CopyAttrsFusedBatchNorm, AlwaysRewrite});
+    rinfo_.push_back(
+        {csinfo_.fused_batch_norm_grad,
+         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
+         CopyAttrsFusedBatchNorm, AlwaysRewrite});
     rinfo_.push_back({csinfo_.identity,
                       mkl_op_registry::GetMklOpName(csinfo_.identity),
                       CopyAttrsDataType, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.lrn,
-                      mkl_op_registry::GetMklOpName(csinfo_.lrn),
+    rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
                       CopyAttrsLRN, AlwaysRewrite});
     rinfo_.push_back({csinfo_.lrn_grad,
                       mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
@@ -2515,8 +2510,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
                       mkl_op_registry::GetMklOpName(csinfo_.mul),
                       CopyAttrsDataType, AlwaysRewrite});
     */
-    rinfo_.push_back({csinfo_.relu,
-                      mkl_op_registry::GetMklOpName(csinfo_.relu),
+    rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
                       CopyAttrsDataType, AlwaysRewrite});
     rinfo_.push_back({csinfo_.relu_grad,
                       mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
@@ -2550,8 +2544,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     // Add a rule for merging nodes
     minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
-                      csinfo_.conv2d_with_bias,
-                      GetConv2DOrBiasAdd});
+                      csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd});
     minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad,
                       csinfo_.conv2d_grad_filter_with_bias,
@@ -2846,9 +2839,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   // Default rewrite rule to be used in scenario 1 for rewrite.
   // @return - true (since we want to always rewrite)
-  static bool AlwaysRewrite(const Node* n) {
-    return true;
-  }
+  static bool AlwaysRewrite(const Node* n) { return true; }
   // Check if we are performing pooling on depth or batch. If it is, then we
   // do not rewrite MaxPool node to Mkl version.
@@ -2862,8 +2853,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     std::vector<int32> ksize, strides;
     CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true);
     CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true);
-    CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(),
-             true);
+    CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true);
    CHECK_EQ(FormatFromString(data_format_str, &data_format), true);
// Condition that specifies non-batch-wise and non-depth-wise pooling. // Condition that specifies non-batch-wise and non-depth-wise pooling.
@ -2941,8 +2931,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// @output output_nodes - the list of new nodes creating Mkl tensors // @output output_nodes - the list of new nodes creating Mkl tensors
// //
// @return None // @return None
void GetNodesProducingMklTensorList(std::unique_ptr<Graph>* g, void GetNodesProducingMklTensorList(
Node* orig_node, const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, std::unique_ptr<Graph>* g, Node* orig_node,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
int* input_idx, int list_length, int* input_idx, int list_length,
std::vector<NodeBuilder::NodeOut>* output_nodes); std::vector<NodeBuilder::NodeOut>* output_nodes);
@ -2961,7 +2952,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// will feed the tensor // will feed the tensor
// @return None // @return None
void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node, void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node,
Node* n, int n_output_slot, Node** mkl_node, int* mkl_node_output_slot); Node* n, int n_output_slot, Node** mkl_node,
int* mkl_node_output_slot);
// Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
// in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
@ -3115,8 +3107,8 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
// the same frame. // the same frame.
if (orig_node->num_inputs() > 0) { if (orig_node->num_inputs() > 0) {
Node* orig_input0 = nullptr; Node* orig_input0 = nullptr;
TF_CHECK_OK(orig_node->input_node(0, TF_CHECK_OK(
const_cast<const Node**>(&orig_input0))); orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
// Allow duplicate while adding control edge as it would fail (return // Allow duplicate while adding control edge as it would fail (return
// NULL) if we try to add duplicate edge. // NULL) if we try to add duplicate edge.
CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true)); CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true));
@ -3126,11 +3118,9 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
} }
void MklLayoutRewritePass::GetNodesProducingMklTensorList( void MklLayoutRewritePass::GetNodesProducingMklTensorList(
std::unique_ptr<Graph>* g, std::unique_ptr<Graph>* g, Node* orig_node,
Node* orig_node, const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
int* input_idx, int list_length,
std::vector<NodeBuilder::NodeOut>* output_nodes) {
CHECK_LT(*input_idx, inputs.size()); CHECK_LT(*input_idx, inputs.size());
CHECK_GT(list_length, 0); CHECK_GT(list_length, 0);
CHECK_NOTNULL(output_nodes); CHECK_NOTNULL(output_nodes);
@ -3147,8 +3137,8 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
int mkl_node_output_slot = 0; int mkl_node_output_slot = 0;
GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node, GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
&mkl_node_output_slot); &mkl_node_output_slot);
output_nodes->push_back(NodeBuilder::NodeOut(mkl_node, output_nodes->push_back(
mkl_node_output_slot)); NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
(*input_idx)++; (*input_idx)++;
list_length--; list_length--;
} }
@ -3158,9 +3148,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
// node that we are constructing. An input node could be (1) 'n' // node that we are constructing. An input node could be (1) 'n'
// if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
// if 'n' is not an Mkl layer. // if 'n' is not an Mkl layer.
void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, void MklLayoutRewritePass::GetNodeProducingMklTensor(
Node* orig_node, Node* n, std::unique_ptr<Graph>* g, Node* orig_node, Node* n, int n_output_slot,
int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) { Node** mkl_node, int* mkl_node_output_slot) {
CHECK_NOTNULL(n); CHECK_NOTNULL(n);
CHECK_NOTNULL(mkl_node); CHECK_NOTNULL(mkl_node);
CHECK_NOTNULL(mkl_node_output_slot); CHECK_NOTNULL(mkl_node_output_slot);
@ -3292,8 +3282,8 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
if (ArgIsList(arg)) { if (ArgIsList(arg)) {
std::vector<NodeBuilder::NodeOut> new_node_inputs; std::vector<NodeBuilder::NodeOut> new_node_inputs;
int N = GetTensorListLength(arg, old_node); int N = GetTensorListLength(arg, old_node);
GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N,
N, &new_node_inputs); &new_node_inputs);
nb->Input(new_node_inputs); nb->Input(new_node_inputs);
nn_slot_idx++; nn_slot_idx++;
} else { } else {
@ -3413,8 +3403,8 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
// the same frame. // the same frame.
if (orig_node->num_inputs() > 0) { if (orig_node->num_inputs() > 0) {
Node* orig_input0 = nullptr; Node* orig_input0 = nullptr;
TF_CHECK_OK(orig_node->input_node(0, TF_CHECK_OK(
const_cast<const Node**>(&orig_input0))); orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
// Allow duplicate while adding control edge as it would fail (return // Allow duplicate while adding control edge as it would fail (return
// NULL) if we try to add duplicate edge. // NULL) if we try to add duplicate edge.
CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true)); CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true));
@ -3434,8 +3424,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
for (auto ws : wsinfo_) { for (auto ws : wsinfo_) {
if (orig_node->type_string() == ws.fwd_op && if (orig_node->type_string() == ws.fwd_op &&
mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName( mkl_op_registry::IsMklOp(
orig_node->type_string()), T)) { mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
// If this op is a fwd op, then we need to check if there is an // If this op is a fwd op, then we need to check if there is an
// edge from this node's fwd_slot to bwdop's bwd_slot. If there is // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
// an edge, then we just add an attribute on this node for setting // an edge, then we just add an attribute on this node for setting
@ -3461,8 +3451,9 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
nb->Attr("workspace_enabled", false); nb->Attr("workspace_enabled", false);
} }
} else if (orig_node->type_string() == ws.bwd_op && } else if (orig_node->type_string() == ws.bwd_op &&
mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName( mkl_op_registry::IsMklOp(
orig_node->type_string()), T)) { mkl_op_registry::GetMklOpName(orig_node->type_string()),
T)) {
// If this op is a bwd op, then we need to add workspace edge and // If this op is a bwd op, then we need to add workspace edge and
// it's Mkl tensor edge between its corresponding fwd op and this // it's Mkl tensor edge between its corresponding fwd op and this
// op. Corresponding fwd op is specified in 'fwd_op' field of // op. Corresponding fwd op is specified in 'fwd_op' field of
@ -3477,8 +3468,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
if (e->src_output() == ws.fwd_slot && if (e->src_output() == ws.fwd_slot &&
// We would have rewritten the forward op, so we need to use // We would have rewritten the forward op, so we need to use
// GetMklOpName call to get its Mkl name. // GetMklOpName call to get its Mkl name.
e->src()->type_string() == mkl_op_registry::GetMklOpName( e->src()->type_string() ==
ws.fwd_op) && mkl_op_registry::GetMklOpName(ws.fwd_op) &&
e->dst_input() == ws.bwd_slot) { e->dst_input() == ws.bwd_slot) {
nb->Attr("workspace_enabled", true); nb->Attr("workspace_enabled", true);
CHECK_NOTNULL(ws_tensors); CHECK_NOTNULL(ws_tensors);
@ -3777,7 +3768,8 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
CHECK_EQ(((m->type_string() == csinfo_.bias_add && CHECK_EQ(((m->type_string() == csinfo_.bias_add &&
n->type_string() == csinfo_.conv2d)) || n->type_string() == csinfo_.conv2d)) ||
((n->type_string() == csinfo_.bias_add && ((n->type_string() == csinfo_.bias_add &&
m->type_string() == csinfo_.conv2d)), true); m->type_string() == csinfo_.conv2d)),
true);
// If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd, // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd,
// BiasAdd is successor node, and Conv2D predecessor node. // BiasAdd is successor node, and Conv2D predecessor node.
@ -3796,8 +3788,7 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides)); TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred)); TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ)); TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
TF_CHECK_OK( TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu));
GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu));
// We check to ensure that data formats of both succ and pred are same. // We check to ensure that data formats of both succ and pred are same.
// We expect them to be same, so we can enforce this as assert. // We expect them to be same, so we can enforce this as assert.
// But assert can be too strict, so we enforce this as a check. // But assert can be too strict, so we enforce this as a check.
@ -3900,8 +3891,8 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
// BiasAdd has only 1 output (at slot 0) and merged node also has only 1 // BiasAdd has only 1 output (at slot 0) and merged node also has only 1
// output (at slot 0). // output (at slot 0).
const int kConv2DWithBiasOutputSlot = 0; const int kConv2DWithBiasOutputSlot = 0;
CHECK_NOTNULL((*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot, CHECK_NOTNULL((*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot, e->dst(),
e->dst(), e->dst_input())); e->dst_input()));
} }
} }
@ -3925,7 +3916,8 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad && CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad &&
n->type_string() == csinfo_.conv2d_grad_filter)) || n->type_string() == csinfo_.conv2d_grad_filter)) ||
((n->type_string() == csinfo_.bias_add_grad && ((n->type_string() == csinfo_.bias_add_grad &&
m->type_string() == csinfo_.conv2d_grad_filter)), true); m->type_string() == csinfo_.conv2d_grad_filter)),
true);
// If 'm' is BiasAddGrad, then 'n' is BackpropFilter. // If 'm' is BiasAddGrad, then 'n' is BackpropFilter.
Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n; Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n;
@ -4132,8 +4124,9 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
// NULL) if we try to add duplicate edge. // NULL) if we try to add duplicate edge.
CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
} else { } else {
CHECK_NOTNULL((*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(), CHECK_NOTNULL((*g)->AddEdge(
e->src()->num_outputs()), new_node,
GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
e->dst(), e->dst_input())); e->dst(), e->dst_input()));
} }
} }
@ -4166,8 +4159,8 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
// names. // names.
if (n->type_string() != csinfo_.conv2d_with_bias && if (n->type_string() != csinfo_.conv2d_with_bias &&
n->type_string() != csinfo_.conv2d_grad_filter_with_bias && n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
!mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName( !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
n->type_string()), T)) { T)) {
return nullptr; return nullptr;
} }
@ -4191,22 +4184,23 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
int num_parent = 0; int num_parent = 0;
for (auto parent : n->in_edges()) { for (auto parent : n->in_edges()) {
if (mkl_op_registry::IsMklOp(parent->src()->type_string(), T)) { if (mkl_op_registry::IsMklOp(parent->src()->type_string(), T)) {
VLOG(1) << "ELEMENTWISE: parent " << num_parent++ << " is MKL op: " VLOG(1) << "ELEMENTWISE: parent " << num_parent++
<< parent->src()->type_string(); << " is MKL op: " << parent->src()->type_string();
incoming_mkl_edge = true; incoming_mkl_edge = true;
break; break;
} else { } else {
VLOG(1) << "ELEMENTWISE: parent " << num_parent++ << " is NON-MKL op: " VLOG(1) << "ELEMENTWISE: parent " << num_parent++
<< parent->src()->type_string(); << " is NON-MKL op: " << parent->src()->type_string();
} }
} }
if (incoming_mkl_edge == false) { if (incoming_mkl_edge == false) {
VLOG(1) << "ELEMENTWISE: Skipping replacement of elementwise node which has no MKL " VLOG(1) << "ELEMENTWISE: Skipping replacement of elementwise node which "
"has no MKL "
"parents."; "parents.";
return nullptr; return nullptr;
} else { } else {
VLOG(1) << "ELEMENTWISE: Replacing elementwise node " << n->type_string() << VLOG(1) << "ELEMENTWISE: Replacing elementwise node " << n->type_string()
" which has MKL parents"; << " which has MKL parents";
} }
} }
@ -4214,8 +4208,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
// for this op, then we rewrite it to Mkl op. // for this op, then we rewrite it to Mkl op.
// Find matching RewriteInfo and then check that rewrite rule applies. // Find matching RewriteInfo and then check that rewrite rule applies.
for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) { for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
if (n->type_string().compare(ri->name) == 0 && if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
ri->rewrite_rule(n)) {
return &*ri; return &*ri;
} }
} }
@ -4297,8 +4290,7 @@ bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
return MklLayoutRewritePass().RunPass(g); return MklLayoutRewritePass().RunPass(g);
} }
Status MklLayoutRewritePass::Run( Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
const GraphOptimizationPassOptions& options) {
if (options.graph == nullptr && options.partition_graphs == nullptr) { if (options.graph == nullptr && options.partition_graphs == nullptr) {
return Status::OK(); return Status::OK();
} }
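The two identical Run() definitions reformatted above follow the standard GraphOptimizationPass contract: depending on the phase a pass is registered for, it receives either a whole graph in options.graph or per-device partitioned graphs in options.partition_graphs, and it must tolerate whichever pointer is null. A minimal sketch of that contract (the pass name and body are illustrative, not part of this commit):

// Sketch only: mirrors the null-tolerant Run() pattern visible above.
class NullTolerantExamplePass : public GraphOptimizationPass {
 public:
  Status Run(const GraphOptimizationPassOptions& options) override {
    if (options.graph == nullptr && options.partition_graphs == nullptr) {
      return Status::OK();  // nothing to rewrite in this registration phase
    }
    // A real pass would inspect and mutate the graph(s) here.
    return Status::OK();
  }
};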
View File
@@ -125,8 +125,10 @@ REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful();
 REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
 REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful();
 REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
-REGISTER_OP("_MklInput2").Output("o: uint8")
-    .Output("o1: uint8").SetIsStateful();
+REGISTER_OP("_MklInput2")
+    .Output("o: uint8")
+    .Output("o1: uint8")
+    .SetIsStateful();
 
 /////////////////////////////////////////////////////////////////////
 //  Unit tests related to node merge optiimization
@@ -498,7 +500,6 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative2) {
             "M->I:3;N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5");
 }
 
-
 // BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only
 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Positive) {
   InitGraph(
@@ -874,7 +875,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
       " input: ['A', 'B:0', 'B:1']}"
       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
       " input: ['C', 'D'] }");
-  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+  EXPECT_EQ(
+      DoMklLayoutOptimizationPass(),
      "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
      "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;"
      "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;"
@@ -1273,7 +1275,8 @@ TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) {
      "node { name: 'H' op: 'Input'}"
      "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['H', 'G'] }");
-  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+  EXPECT_EQ(
+      DoMklLayoutOptimizationPass(),
      "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);"
      "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);"
      "I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;"
@@ -1640,7 +1643,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) {
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['B', 'C'] }", kGPUDevice);
+      " input: ['B', 'C'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1");
 }
@@ -1666,7 +1670,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) {
      "node { name: 'F' op: 'BiasAddGrad'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
-      " input: ['E'] }", kGPUDevice);
+      " input: ['E'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
             "E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);"
@@ -1687,7 +1692,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) {
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B', 'C']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'D'] }", kGPUDevice);
+      " input: ['A', 'D'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|"
             "A->D;A->E;B->D:1;C->D:2;D->E:1");
@@ -1700,7 +1706,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) {
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'B'] }", kGPUDevice);
+      " input: ['A', 'B'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1");
 }
@@ -1713,7 +1720,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) {
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'C'] }", kGPUDevice);
+      " input: ['A', 'C'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1");
 }
@@ -1729,7 +1737,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) {
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'B'] }", kGPUDevice);
+      " input: ['A', 'B'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
 }
@@ -1745,7 +1754,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) {
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'B'] }", kGPUDevice);
+      " input: ['A', 'B'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1");
 }
@@ -1766,7 +1776,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) {
      " attr { key: 'N' value { i: 2 } }"
      " input: ['A', 'B:0', 'B:1']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['C', 'D'] }", kGPUDevice);
+      " input: ['C', 'D'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;"
             "B->D:1;B:1->D:2;C->E;D->E:1");
@@ -1788,7 +1799,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) {
      " attr { key: 'N' value { i: 2 } }"
      " input: ['B:0', 'B:1', 'A']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['C', 'D'] }", kGPUDevice);
+      " input: ['C', 'D'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|"
             "A->D:2;B->D;B:1->D:1;C->E;D->E:1");
@@ -1808,7 +1820,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) {
      " attr { key: 'is_training' value { b: true } }"
      " input: ['A', 'B', 'C', 'D', 'E'] }"
      "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'F'] }", kGPUDevice);
+      " input: ['A', 'F'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Input);D(Input);E(Input);"
             "F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;"
@@ -1837,7 +1850,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
      "node { name: 'Y' op: 'Input'}"
      "node { name: 'Z' op: 'Zeta'"
      " attr {key: 'T' value { type: DT_FLOAT } }"
-      " input: ['E', 'Y']}", kGPUDevice);
+      " input: ['E', 'Y']}",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);"
             "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;"
@@ -1972,8 +1986,10 @@ REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful();
 REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
 REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful();
 REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
-REGISTER_OP("_MklInput2").Output("o: uint8")
-    .Output("o1: uint8").SetIsStateful();
+REGISTER_OP("_MklInput2")
+    .Output("o: uint8")
+    .Output("o1: uint8")
+    .SetIsStateful();
 
 /////////////////////////////////////////////////////////////////////
 //  Unit tests related to node merge optiimization
@@ -2492,7 +2508,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
      " input: ['A', 'B:0', 'B:1']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'D'] }");
-  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+  EXPECT_EQ(
+      DoMklLayoutOptimizationPass(),
      "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
      "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;"
      "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;"
@@ -2891,7 +2908,8 @@ TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) {
      "node { name: 'H' op: 'Input'}"
      "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['H', 'G'] }");
-  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+  EXPECT_EQ(
+      DoMklLayoutOptimizationPass(),
      "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);"
      "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);"
      "I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;"
@@ -3258,7 +3276,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) {
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B']}"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['B', 'C'] }", kGPUDevice);
+      " input: ['B', 'C'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1");
 }
@@ -3284,7 +3303,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) {
      "node { name: 'F' op: 'BiasAddGrad'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'data_format' value { s: 'NCHW' } }"
-      " input: ['E'] }", kGPUDevice);
+      " input: ['E'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
             "E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);"
@@ -3305,7 +3325,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) {
      " attr { key: 'padding' value { s: 'SAME' } }"
      " input: ['A', 'B', 'C']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'D'] }", kGPUDevice);
+      " input: ['A', 'D'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|"
             "A->D;A->E;B->D:1;C->D:2;D->E:1");
@@ -3318,7 +3339,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) {
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'B'] }", kGPUDevice);
+      " input: ['A', 'B'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1");
 }
@@ -3331,7 +3353,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) {
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'C'] }", kGPUDevice);
+      " input: ['A', 'C'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1");
 }
@@ -3347,7 +3370,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) {
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'B'] }", kGPUDevice);
+      " input: ['A', 'B'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
 }
@@ -3363,7 +3387,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) {
      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
      " input: ['A'] }"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'B'] }", kGPUDevice);
+      " input: ['A', 'B'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1");
 }
@@ -3384,7 +3409,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) {
      " attr { key: 'N' value { i: 2 } }"
      " input: ['A', 'B:0', 'B:1']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['C', 'D'] }", kGPUDevice);
+      " input: ['C', 'D'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;"
             "B->D:1;B:1->D:2;C->E;D->E:1");
@@ -3406,7 +3432,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) {
      " attr { key: 'N' value { i: 2 } }"
      " input: ['B:0', 'B:1', 'A']}"
      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['C', 'D'] }", kGPUDevice);
+      " input: ['C', 'D'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|"
             "A->D:2;B->D;B:1->D:1;C->E;D->E:1");
@@ -3426,7 +3453,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) {
      " attr { key: 'is_training' value { b: true } }"
      " input: ['A', 'B', 'C', 'D', 'E'] }"
      "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'F'] }", kGPUDevice);
+      " input: ['A', 'F'] }",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Input);D(Input);E(Input);"
             "F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;"
@@ -3455,7 +3483,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
      "node { name: 'Y' op: 'Input'}"
      "node { name: 'Z' op: 'Zeta'"
      " attr {key: 'T' value { type: DT_FLOAT } }"
-      " input: ['E', 'Y']}", kGPUDevice);
+      " input: ['E', 'Y']}",
+      kGPUDevice);
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);"
             "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;"
View File
@@ -33,8 +33,8 @@ limitations under the License.
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/platform/logging.h"
 
-#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
 #include "tensorflow/core/graph/mkl_graph_util.h"
+#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
 
 namespace tensorflow {
@@ -152,12 +152,12 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
   string data_format;
 
   TF_CHECK_OK(GetNodeAttr(src->def(), "T", &src_datatype));
-  bool dst_dtype_found = GetNodeAttr(dst->def(), "T", &dst_datatype) ==
-                         Status::OK();
+  bool dst_dtype_found =
+      GetNodeAttr(dst->def(), "T", &dst_datatype) == Status::OK();
   // We compare source and destination datatypes only when both are found.
   if (dst_dtype_found && (src_datatype != dst_datatype)) {
-    string err_msg = "T attribute of " + src->name() + " and " +
-                     dst->name() + " do not match. Will not insert" +
+    string err_msg = "T attribute of " + src->name() + " and " + dst->name() +
+                     " do not match. Will not insert" +
                      " MklToTf node in such case.";
     return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str());
   }
@@ -325,11 +325,11 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
       // may not be Mkl node.
       DataType src_datatype;
       DataType dst_datatype;
-      bool src_is_mkl_op = (GetNodeAttr(src->def(), "T", &src_datatype) ==
-                                Status::OK() &&
-                            IsMklSupportedOp(src->type_string(), src_datatype));
-      bool dst_is_mkl_op = (GetNodeAttr(dst->def(), "T", &dst_datatype) ==
-                                Status::OK() &&
-                            IsMklSupportedOp(dst->type_string(), dst_datatype));
+      bool src_is_mkl_op =
+          (GetNodeAttr(src->def(), "T", &src_datatype) == Status::OK() &&
+           IsMklSupportedOp(src->type_string(), src_datatype));
+      bool dst_is_mkl_op =
+          (GetNodeAttr(dst->def(), "T", &dst_datatype) == Status::OK() &&
+           IsMklSupportedOp(dst->type_string(), dst_datatype));
 
       // Check if src with is Mkl-compliant, while dst is not Mkl-compliant.
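Both hunks above reformat the same probing idiom: GetNodeAttr() returns a Status, and comparing it against Status::OK() tests whether an optional "T" datatype attribute exists without aborting when it does not (unlike the TF_CHECK_OK form used for 'src'). A hedged one-liner capturing the idiom; the helper name is made up:

// Sketch: true iff node 'n' carries a "T" attr; on success *dt is filled in.
static bool TryGetDatatype(const Node* n, DataType* dt) {
  return GetNodeAttr(n->def(), "T", dt) == Status::OK();
}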
View File
@@ -55,9 +55,8 @@ Status ScheduleTask(size_t task_size, BatchScheduler<FakeTask>* scheduler) {
 // use the clock to be destroyed.
 std::unique_ptr<Thread> CreateFakeClockAdvancerThread(
     test_util::FakeClockEnv* env, Notification* start, Notification* stop) {
-  return std::unique_ptr<Thread>(
-      Env::Default()->StartThread({}, "FakeClockAdvancerThread",
-                                  [env, start, stop] {
+  return std::unique_ptr<Thread>(Env::Default()->StartThread(
+      {}, "FakeClockAdvancerThread", [env, start, stop] {
         start->WaitForNotification();
         while (!stop->HasBeenNotified()) {
           env->AdvanceByMicroseconds(10);
View File
@@ -92,7 +92,6 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
   }
 
  private:
-
   class Iterator : public DatasetIterator<Dataset> {
    public:
    explicit Iterator(const Params& params)
View File
#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/notification.h"
namespace tensorflow { namespace tensorflow {
/* static */ /* static */
@ -185,8 +184,7 @@ Status CapturedFunction::MaybeInstantiate(
return Status::OK(); return Status::OK();
} }
Status CapturedFunction::Run(IteratorContext* ctx, Status CapturedFunction::Run(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>&& args,
std::vector<Tensor>* rets) { std::vector<Tensor>* rets) {
FunctionLibraryRuntime::Handle handle; FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle)); TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle));
View File
@@ -128,8 +128,8 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
         while (i_ < dataset()->count_) {
           // Fetch and throw away Tensors.
           std::vector<Tensor> dummy_out_tensors;
-          TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &dummy_out_tensors,
-                                                  end_of_sequence));
+          TF_RETURN_IF_ERROR(
+              input_impl_->GetNext(ctx, &dummy_out_tensors, end_of_sequence));
           if (*end_of_sequence) {
             // We reached the end before the count was reached.
             input_impl_.reset();
@@ -140,8 +140,8 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
         }
 
         // Return GetNext() on the underlying iterator.
-        TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors,
-                                                end_of_sequence));
+        TF_RETURN_IF_ERROR(
+            input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
         if (*end_of_sequence) {
           input_impl_.reset();
         }
@@ -184,8 +184,7 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
   };
 };
 
-REGISTER_KERNEL_BUILDER(Name("SkipDataset").Device(DEVICE_CPU),
-                        SkipDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("SkipDataset").Device(DEVICE_CPU), SkipDatasetOp);
 
 }  // namespace
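The Skip iterator reformatted above implements skipping lazily: the first GetNext() call drains and discards count_ elements from the wrapped iterator, resets it if the input ends early, and only then forwards the caller's request. A condensed sketch of that control flow, written as a free function for illustration (the real logic lives in the Iterator class above; only the GetNext(ctx, out, end_of_sequence) contract is assumed):

// Sketch: '*i' counts elements skipped so far, mirroring the iterator's i_.
Status SkipThenGetNext(IteratorContext* ctx, IteratorBase* input, int64 count,
                       int64* i, std::vector<Tensor>* out, bool* eos) {
  while (*i < count) {
    std::vector<Tensor> dummy;  // fetched only to advance the input
    TF_RETURN_IF_ERROR(input->GetNext(ctx, &dummy, eos));
    if (*eos) return Status::OK();  // input exhausted before 'count' skips
    ++(*i);
  }
  return input->GetNext(ctx, out, eos);  // first element past the skip
}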
View File
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {

View File

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {

View File

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {

View File

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {

View File

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {

View File

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {

View File

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {

View File

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {

View File

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {

View File

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {

View File

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
 
 namespace tensorflow {
 namespace fuzzing {

View File
@@ -46,7 +46,7 @@ GraphTransferUtils::GetTopNFloatResults(const float* const data,
       GetTopNFloatResults(data, labels, element_count);
   LOG(INFO) << "=== Dump ranking ===";
   for (int i = 0; i < top_n; ++i) {
-    const std::tuple<float, int, string> &entry = queue.top();
+    const std::tuple<float, int, string>& entry = queue.top();
     LOG(INFO) << i << ": " << std::get<1>(entry) << ", " << std::get<2>(entry)
               << ", " << std::get<0>(entry);
     queue.pop();
View File
@@ -181,8 +181,8 @@ class GraphTransferer {
   void AppendNodeInputParams(const int id, const Node& node,
                              const std::vector<int>& extra_inputs);
 
-  void AppendNodeOutputParams(const ShapeRefiner& shape_refiner,
-                              const int id, const Node& node);
+  void AppendNodeOutputParams(const ShapeRefiner& shape_refiner, const int id,
+                              const Node& node);
 
   static std::array<int64, SHAPE_ARRAY_SIZE> BuildShapeArray(
       const shape_inference::ShapeHandle& shape_handle,
View File
@@ -42,8 +42,7 @@ constexpr float VALUE_TOLERANCE_FLOAT = 1e-8f;
 
 class GraphTransfererTest : public ::testing::Test {
  protected:
-  void SetUp() final {
-  }
+  void SetUp() final {}
 
   GraphTransferer gt_;
 };
@@ -61,7 +60,7 @@ class TestGraphTransferOpsDefinitions : public IRemoteFusedGraphOpsDefinitions {
       }
     }
     return -1;
   }
 
-  private:
+ private:
   const std::vector<string> op_types_{"INPUT", "OUTPUT", "Conv2D",
View File
@@ -71,10 +71,10 @@ class NeonDepthwiseConv2dNativeOp : public BinaryOp<float> {
                                         filter.shape().DebugString()));
 
     const int32 in_depth = input.dim_size(3);
-    OP_REQUIRES(
-        context, in_depth == filter.dim_size(2),
-        errors::InvalidArgument("input and filter must have the same depth: ",
-                                in_depth, " vs ", filter.dim_size(2)));
+    OP_REQUIRES(context, in_depth == filter.dim_size(2),
+                errors::InvalidArgument(
+                    "input and filter must have the same depth: ", in_depth,
+                    " vs ", filter.dim_size(2)));
     const int32 batch = input.dim_size(0);
     const int32 input_rows = input.dim_size(1);
     const int32 input_cols = input.dim_size(2);
View File
@@ -66,7 +66,9 @@ struct EigenEnvironment {
     }
     return Task{
         std::unique_ptr<TaskImpl>(new TaskImpl{
-            std::move(f), Context(ContextKind::kThread), id,
+            std::move(f),
+            Context(ContextKind::kThread),
+            id,
         }),
     };
   }
View File
@@ -97,8 +97,8 @@ TEST(ThreadPool, ParallelForWithWorkerId) {
     }
     pool.ParallelForWithWorkerId(
         kWorkItems, kHugeCost,
-        [&threads_running, &work, num_threads](
-            int64 begin, int64 end, int64 id) {
+        [&threads_running, &work, num_threads](int64 begin, int64 end,
+                                               int64 id) {
          // Store true for the current thread, and assert that another thread
          // is not running with the same id.
          ASSERT_LE(0, id);
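ParallelForWithWorkerId, exercised by the test above, shards [0, total) into [begin, end) ranges and hands each shard the id of the worker running it, so callers can keep per-worker state without locking. A small usage sketch, assumed to sit inside namespace tensorflow; the function name and constants are illustrative, and the table is sized one past NumThreads() on the assumption that ids never exceed the thread count (the visible assertion only bounds them below):

#include <vector>
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/env.h"

void CountItemsPerWorker() {
  thread::ThreadPool pool(Env::Default(), "demo", /*num_threads=*/4);
  std::vector<int64> items_per_worker(pool.NumThreads() + 1, 0);
  pool.ParallelForWithWorkerId(
      /*total=*/1000, /*cost_per_unit=*/1 << 20,
      [&items_per_worker](int64 begin, int64 end, int64 id) {
        // Distinct concurrent shards get distinct ids, so no lock is needed.
        items_per_worker[id] += end - begin;
      });
}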
View File
@@ -18,12 +18,12 @@ limitations under the License.
 #include <mutex>
 
 #include "sqlite3.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/lib/core/refcount.h"
 
 /// TensorFlow SQLite Veneer
 ///
@@ -121,10 +121,7 @@ class LOCKABLE Sqlite : public core::RefCounted {
   Sqlite(sqlite3* db, sqlite3_stmt* begin, sqlite3_stmt* commit,
          sqlite3_stmt* rollback) noexcept
-      : db_(db),
-        begin_(begin),
-        commit_(commit),
-        rollback_(rollback) {}
+      : db_(db), begin_(begin), commit_(commit), rollback_(rollback) {}
 
   sqlite3* const db_;
   sqlite3_stmt* const begin_;
@@ -233,7 +230,8 @@ class SqliteStatement {
   /// freed until this statement is Reset() or finalized.
   void BindText(int parameter, const StringPiece& text) {
     Update(sqlite3_bind_text64(stmt_, parameter, text.data(), text.size(),
-                               SQLITE_TRANSIENT, SQLITE_UTF8), parameter);
+                               SQLITE_TRANSIENT, SQLITE_UTF8),
+           parameter);
     size_ += text.size();
   }
   void BindText(const char* parameter, const StringPiece& text) {
@@ -241,7 +239,8 @@ class SqliteStatement {
   }
   void BindTextUnsafe(int parameter, const StringPiece& text) {
     Update(sqlite3_bind_text64(stmt_, parameter, text.data(), text.size(),
-                               SQLITE_STATIC, SQLITE_UTF8), parameter);
+                               SQLITE_STATIC, SQLITE_UTF8),
+           parameter);
     size_ += text.size();
   }
   void BindTextUnsafe(const char* parameter, const StringPiece& text) {
@@ -254,7 +253,8 @@ class SqliteStatement {
   /// freed until this statement is Reset() or finalized.
   void BindBlob(int parameter, const StringPiece& blob) {
     Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(),
-                               SQLITE_TRANSIENT), parameter);
+                               SQLITE_TRANSIENT),
+           parameter);
     size_ += blob.size();
   }
   void BindBlob(const char* parameter, const StringPiece& blob) {
@@ -262,7 +262,8 @@ class SqliteStatement {
   }
   void BindBlobUnsafe(int parameter, const StringPiece& blob) {
     Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(),
-                               SQLITE_STATIC), parameter);
+                               SQLITE_STATIC),
+           parameter);
     size_ += blob.size();
   }
   void BindBlobUnsafe(const char* parameter, const StringPiece& text) {
@@ -320,9 +321,7 @@ class SqliteStatement {
   /// \brief Move constructor, after which <other> is reset to empty.
   SqliteStatement(SqliteStatement&& other) noexcept
-      : db_(other.db_),
-        stmt_(other.stmt_),
-        bind_error_(other.bind_error_) {
+      : db_(other.db_), stmt_(other.stmt_), bind_error_(other.bind_error_) {
     other.db_ = nullptr;
     other.stmt_ = nullptr;
     other.bind_error_ = SQLITE_OK;


@ -33,9 +33,7 @@ class SqliteTest : public ::testing::Test {
db_->PrepareOrDie("CREATE TABLE T (a BLOB, b BLOB)").StepAndResetOrDie(); db_->PrepareOrDie("CREATE TABLE T (a BLOB, b BLOB)").StepAndResetOrDie();
} }
void TearDown() override { void TearDown() override { db_->Unref(); }
db_->Unref();
}
Sqlite* db_; Sqlite* db_;
bool is_done_; bool is_done_;
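
Tying the veneer together, a sketch of typical statement use, restricted to calls visible in the header and fixture above (the table, values, and control flow are illustrative, and the exact Step* API is assumed from the fixture):

    // Assumes an open, ref-counted Sqlite* db, as in the fixture above.
    db->PrepareOrDie("CREATE TABLE T (a BLOB, b BLOB)").StepAndResetOrDie();
    SqliteStatement insert = db->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
    const string payload = "raw bytes";
    insert.BindText(1, "hello");        // SQLITE_TRANSIENT: copied immediately
    insert.BindBlobUnsafe(2, payload);  // SQLITE_STATIC: payload must outlive the step
    insert.StepAndResetOrDie();
    db->Unref();  // Sqlite is core::RefCounted

The Bind/BindUnsafe split mirrors the SQLITE_TRANSIENT versus SQLITE_STATIC lifetime contract in the header: the safe variants copy, the unsafe variants borrow until the statement is Reset() or finalized.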


@ -55,22 +55,21 @@ namespace gtl {
template <typename F> template <typename F>
class Cleanup { class Cleanup {
public: public:
Cleanup() Cleanup() : released_(true), f_() {}
: released_(true), f_() {}
template <typename G> template <typename G>
explicit Cleanup(G&& f) // NOLINT explicit Cleanup(G&& f) // NOLINT
: f_(std::forward<G>(f)) {} // NOLINT(build/c++11) : f_(std::forward<G>(f)) {} // NOLINT(build/c++11)
Cleanup(Cleanup&& src) // NOLINT Cleanup(Cleanup&& src) // NOLINT
: released_(src.is_released()), f_(src.release()) { } : released_(src.is_released()), f_(src.release()) {}
// Implicitly move-constructible from any compatible Cleanup<G>. // Implicitly move-constructible from any compatible Cleanup<G>.
// The source will be released as if src.release() were called. // The source will be released as if src.release() were called.
// A moved-from Cleanup can be safely destroyed or reassigned. // A moved-from Cleanup can be safely destroyed or reassigned.
template <typename G> template <typename G>
Cleanup(Cleanup<G>&& src) // NOLINT Cleanup(Cleanup<G>&& src) // NOLINT
: released_(src.is_released()), f_(src.release()) { } : released_(src.is_released()), f_(src.release()) {}
// Assignment to a Cleanup object behaves like destroying it // Assignment to a Cleanup object behaves like destroying it
// and making a new one in its place, analogous to unique_ptr // and making a new one in its place, analogous to unique_ptr
@ -102,8 +101,8 @@ class Cleanup {
F f_; F f_;
}; };
template <int&... ExplicitParameterBarrier, template <int&... ExplicitParameterBarrier, typename F,
typename F, typename DecayF = typename std::decay<F>::type> typename DecayF = typename std::decay<F>::type>
TF_MUST_USE_RESULT Cleanup<DecayF> MakeCleanup(F&& f) { TF_MUST_USE_RESULT Cleanup<DecayF> MakeCleanup(F&& f) {
return Cleanup<DecayF>(std::forward<F>(f)); return Cleanup<DecayF>(std::forward<F>(f));
} }


@ -65,15 +65,14 @@ TEST(CleanupTest, Release) {
TEST(FinallyTest, TypeErasedWithoutFactory) { TEST(FinallyTest, TypeErasedWithoutFactory) {
string s = "active"; string s = "active";
{ {
AnyCleanup s_cleaner([&s]{ s.append(" clean"); }); AnyCleanup s_cleaner([&s] { s.append(" clean"); });
EXPECT_EQ("active", s); EXPECT_EQ("active", s);
} }
EXPECT_EQ("active clean", s); EXPECT_EQ("active clean", s);
} }
struct Appender { struct Appender {
Appender(string* s, const string& msg) Appender(string* s, const string& msg) : s_(s), msg_(msg) {}
: s_(s), msg_(msg) {}
void operator()() const { s_->append(msg_); } void operator()() const { s_->append(msg_); }
string* s_; string* s_;
string msg_; string msg_;
@ -163,7 +162,12 @@ class CleanupReferenceTest : public ::testing::Test {
int* i; int* i;
F(int* cp, int* i) : cp(cp), i(i) {} F(int* cp, int* i) : cp(cp), i(i) {}
F(const F& o) : cp(o.cp), i(o.i) { ++*cp; } F(const F& o) : cp(o.cp), i(o.i) { ++*cp; }
F& operator=(const F& o) { cp = o.cp; i = o.i; ++*cp; return *this; } F& operator=(const F& o) {
cp = o.cp;
i = o.i;
++*cp;
return *this;
}
F(F&&) = default; F(F&&) = default;
F& operator=(F&&) = default; F& operator=(F&&) = default;
void operator()() const { ++*i; } void operator()() const { ++*i; }
@ -279,7 +283,7 @@ BENCHMARK(BM_AnyCleanup);
void BM_AnyCleanupNoFactory(int iters) { void BM_AnyCleanupNoFactory(int iters) {
while (iters--) { while (iters--) {
AnyCleanup fin([]{Incr();}); AnyCleanup fin([] { Incr(); });
} }
} }
BENCHMARK(BM_AnyCleanupNoFactory); BENCHMARK(BM_AnyCleanupNoFactory);
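
The churn above is all in gtl::Cleanup, the scope-guard helper: MakeCleanup wraps a callable that runs exactly once when the guard is destroyed. A minimal sketch of the usual pattern (the file handling is illustrative):

    #include <cstdio>

    #include "tensorflow/core/lib/gtl/cleanup.h"

    void ReadHeader(const char* path) {
      FILE* f = fopen(path, "r");
      if (f == nullptr) return;
      // fclose(f) now runs on every exit path from this scope.
      auto closer = tensorflow::gtl::MakeCleanup([f] { fclose(f); });
      char buf[16];
      if (fread(buf, 1, sizeof(buf), f) != sizeof(buf)) return;  // still closed
      // ... use buf; closer.release() would cancel the cleanup instead ...
    }

The TF_MUST_USE_RESULT on MakeCleanup (visible in the hunk) is what forces the guard into a named variable; a discarded temporary would be destroyed, and the callable run, at the end of the same statement.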


@ -31,12 +31,12 @@ limitations under the License.
#ifndef TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_ #ifndef TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_
#define TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_ #define TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_
#include <cstddef>
#include <stddef.h> #include <stddef.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include <sys/types.h> #include <sys/types.h>
#include <algorithm> #include <algorithm>
#include <cstddef>
#include <iterator> #include <iterator>
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
@ -407,7 +407,7 @@ class InlinedVector {
}; };
// 2) Construct a T with args at not-yet-initialized memory pointed by dst. // 2) Construct a T with args at not-yet-initialized memory pointed by dst.
struct Construct { struct Construct {
template<class... Args> template <class... Args>
void operator()(T* dst, Args&&... args) const { void operator()(T* dst, Args&&... args) const {
new (dst) T(std::forward<Args>(args)...); new (dst) T(std::forward<Args>(args)...);
} }
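
The Construct functor is the standard placement-new idiom: build a T inside raw, pre-allocated storage instead of letting new allocate. In isolation:

    #include <new>
    #include <string>

    void PlacementNewDemo() {
      using String = std::string;
      alignas(String) unsigned char storage[sizeof(String)];
      String* s = new (&storage) String("hi");  // construct in place; no allocation for the object itself
      s->~String();  // placement-new'd objects must be destroyed explicitly
    }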


@ -255,13 +255,13 @@ class IntType {
value_ op arg_value; \ value_ op arg_value; \
return *this; \ return *this; \
} }
INT_TYPE_ASSIGNMENT_OP(+= ); INT_TYPE_ASSIGNMENT_OP(+=);
INT_TYPE_ASSIGNMENT_OP(-= ); INT_TYPE_ASSIGNMENT_OP(-=);
INT_TYPE_ASSIGNMENT_OP(*= ); INT_TYPE_ASSIGNMENT_OP(*=);
INT_TYPE_ASSIGNMENT_OP(/= ); INT_TYPE_ASSIGNMENT_OP(/=);
INT_TYPE_ASSIGNMENT_OP(<<= ); // NOLINT INT_TYPE_ASSIGNMENT_OP(<<=); // NOLINT
INT_TYPE_ASSIGNMENT_OP(>>= ); // NOLINT INT_TYPE_ASSIGNMENT_OP(>>=); // NOLINT
INT_TYPE_ASSIGNMENT_OP(%= ); INT_TYPE_ASSIGNMENT_OP(%=);
#undef INT_TYPE_ASSIGNMENT_OP #undef INT_TYPE_ASSIGNMENT_OP
ThisType& operator=(ValueType arg_value) { ThisType& operator=(ValueType arg_value) {
@ -314,10 +314,10 @@ std::ostream& operator<<(std::ostream& os, // NOLINT
INT_TYPE_ARITHMETIC_OP(+); INT_TYPE_ARITHMETIC_OP(+);
INT_TYPE_ARITHMETIC_OP(-); INT_TYPE_ARITHMETIC_OP(-);
INT_TYPE_ARITHMETIC_OP(*); INT_TYPE_ARITHMETIC_OP(*);
INT_TYPE_ARITHMETIC_OP(/ ); INT_TYPE_ARITHMETIC_OP(/);
INT_TYPE_ARITHMETIC_OP(<< ); // NOLINT INT_TYPE_ARITHMETIC_OP(<<); // NOLINT
INT_TYPE_ARITHMETIC_OP(>> ); // NOLINT INT_TYPE_ARITHMETIC_OP(>>); // NOLINT
INT_TYPE_ARITHMETIC_OP(% ); INT_TYPE_ARITHMETIC_OP(%);
#undef INT_TYPE_ARITHMETIC_OP #undef INT_TYPE_ARITHMETIC_OP
// -- NON-MEMBER COMPARISON OPERATORS ------------------------------------------ // -- NON-MEMBER COMPARISON OPERATORS ------------------------------------------
@ -345,12 +345,12 @@ INT_TYPE_ARITHMETIC_OP(% );
IntType<IntTypeName, ValueType> id) { \ IntType<IntTypeName, ValueType> id) { \
return val op id.value(); \ return val op id.value(); \
} }
INT_TYPE_COMPARISON_OP(== ); // NOLINT INT_TYPE_COMPARISON_OP(==); // NOLINT
INT_TYPE_COMPARISON_OP(!= ); // NOLINT INT_TYPE_COMPARISON_OP(!=); // NOLINT
INT_TYPE_COMPARISON_OP(< ); // NOLINT INT_TYPE_COMPARISON_OP(<); // NOLINT
INT_TYPE_COMPARISON_OP(<= ); // NOLINT INT_TYPE_COMPARISON_OP(<=); // NOLINT
INT_TYPE_COMPARISON_OP(> ); // NOLINT INT_TYPE_COMPARISON_OP(>); // NOLINT
INT_TYPE_COMPARISON_OP(>= ); // NOLINT INT_TYPE_COMPARISON_OP(>=); // NOLINT
#undef INT_TYPE_COMPARISON_OP #undef INT_TYPE_COMPARISON_OP
} // namespace gtl } // namespace gtl
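
Each INT_TYPE_ASSIGNMENT_OP(op) invocation stamps out one compound-assignment member, and the diff only drops the old space before the closing parenthesis. Hand-expanding INT_TYPE_ASSIGNMENT_OP(+=) from the macro body at the top of the hunk gives approximately:

    ThisType& operator+=(ValueType arg_value) {
      value_ += arg_value;
      return *this;
    }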


@ -42,7 +42,8 @@ class IntTypeTest : public ::testing::Test {
// All tests below will be executed on all supported IntTypes. // All tests below will be executed on all supported IntTypes.
typedef ::testing::Types<Int8_IT, UInt8_IT, Int16_IT, UInt16_IT, Int32_IT, typedef ::testing::Types<Int8_IT, UInt8_IT, Int16_IT, UInt16_IT, Int32_IT,
Int64_IT, UInt64_IT, Long_IT> SupportedIntTypes; Int64_IT, UInt64_IT, Long_IT>
SupportedIntTypes;
TYPED_TEST_CASE(IntTypeTest, SupportedIntTypes); TYPED_TEST_CASE(IntTypeTest, SupportedIntTypes);
@ -232,7 +233,8 @@ TYPED_TEST(IntTypeTest, TestOperators) {
TYPED_TEST(IntTypeTest, TestHashFunctor) { TYPED_TEST(IntTypeTest, TestHashFunctor) {
std::unordered_map<typename TestFixture::T, char, std::unordered_map<typename TestFixture::T, char,
typename TestFixture::T::Hasher> map; typename TestFixture::T::Hasher>
map;
typename TestFixture::T a(0); typename TestFixture::T a(0);
map[a] = 'c'; map[a] = 'c';
EXPECT_EQ('c', map[a]); EXPECT_EQ('c', map[a]);


@ -593,12 +593,12 @@ class optional : private internal_optional::optional_data<T>,
assert(this->engaged_); assert(this->engaged_);
return this->pointer(); return this->pointer();
} }
constexpr const T& operator*() const & { return reference(); } constexpr const T& operator*() const& { return reference(); }
T& operator*() & { T& operator*() & {
assert(this->engaged_); assert(this->engaged_);
return reference(); return reference();
} }
constexpr const T&& operator*() const && { return std::move(reference()); } constexpr const T&& operator*() const&& { return std::move(reference()); }
T&& operator*() && { T&& operator*() && {
assert(this->engaged_); assert(this->engaged_);
return std::move(reference()); return std::move(reference());
@ -621,7 +621,7 @@ class optional : private internal_optional::optional_data<T>,
// Use `opt.value()` to get a reference to underlying value. The constness // Use `opt.value()` to get a reference to underlying value. The constness
// and lvalue/rvalue-ness of `opt` is preserved to the view of the T // and lvalue/rvalue-ness of `opt` is preserved to the view of the T
// subobject. // subobject.
const T& value() const & { const T& value() const& {
CHECK(*this) << "Bad optional access"; CHECK(*this) << "Bad optional access";
return reference(); return reference();
} }
@ -633,7 +633,7 @@ class optional : private internal_optional::optional_data<T>,
CHECK(*this) << "Bad optional access"; CHECK(*this) << "Bad optional access";
return std::move(reference()); return std::move(reference());
} }
const T&& value() const && { // NOLINT(build/c++11) const T&& value() const&& { // NOLINT(build/c++11)
CHECK(*this) << "Bad optional access"; CHECK(*this) << "Bad optional access";
return std::move(reference()); return std::move(reference());
} }
@ -641,7 +641,7 @@ class optional : private internal_optional::optional_data<T>,
// Use `opt.value_or(val)` to get either the value of T or the given default // Use `opt.value_or(val)` to get either the value of T or the given default
// `val` in the empty case. // `val` in the empty case.
template <class U> template <class U>
constexpr T value_or(U&& v) const & { constexpr T value_or(U&& v) const& {
return static_cast<bool>(*this) ? **this return static_cast<bool>(*this) ? **this
: static_cast<T>(std::forward<U>(v)); : static_cast<T>(std::forward<U>(v));
} }
@ -656,8 +656,8 @@ class optional : private internal_optional::optional_data<T>,
constexpr const T& reference() const { return *this->pointer(); } constexpr const T& reference() const { return *this->pointer(); }
T& reference() { return *(this->pointer()); } T& reference() { return *(this->pointer()); }
// T constraint checks. You can't have an optional of nullopt_t, in_place_t or // T constraint checks. You can't have an optional of nullopt_t, in_place_t
// a reference. // or a reference.
static_assert( static_assert(
!std::is_same<nullopt_t, typename std::remove_cv<T>::type>::value, !std::is_same<nullopt_t, typename std::remove_cv<T>::type>::value,
"optional<nullopt_t> is not allowed."); "optional<nullopt_t> is not allowed.");


@ -24,17 +24,29 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace { namespace {
using tensorflow::gtl::optional;
using tensorflow::gtl::nullopt;
using tensorflow::gtl::nullopt_t;
using tensorflow::gtl::in_place; using tensorflow::gtl::in_place;
using tensorflow::gtl::in_place_t; using tensorflow::gtl::in_place_t;
using tensorflow::gtl::make_optional; using tensorflow::gtl::make_optional;
using tensorflow::gtl::nullopt;
using tensorflow::gtl::nullopt_t;
using tensorflow::gtl::optional;
template <typename T> string TypeQuals(T&) { return "&"; } template <typename T>
template <typename T> string TypeQuals(T&&) { return "&&"; } string TypeQuals(T&) {
template <typename T> string TypeQuals(const T&) { return "c&"; } return "&";
template <typename T> string TypeQuals(const T&&) { return "c&&"; } }
template <typename T>
string TypeQuals(T&&) {
return "&&";
}
template <typename T>
string TypeQuals(const T&) {
return "c&";
}
template <typename T>
string TypeQuals(const T&&) {
return "c&&";
}
struct StructorListener { struct StructorListener {
int construct0 = 0; int construct0 = 0;


@ -28,10 +28,10 @@ limitations under the License.
namespace { namespace {
using tensorflow::string;
using tensorflow::gtl::TopN; using tensorflow::gtl::TopN;
using tensorflow::random::PhiloxRandom; using tensorflow::random::PhiloxRandom;
using tensorflow::random::SimplePhilox; using tensorflow::random::SimplePhilox;
using tensorflow::string;
// Move the contents from an owned raw pointer, returning by value. // Move the contents from an owned raw pointer, returning by value.
// Objects are easier to manage by value. // Objects are easier to manage by value.


@ -22,6 +22,6 @@ namespace compression {
const char kNone[] = ""; const char kNone[] = "";
const char kGzip[] = "GZIP"; const char kGzip[] = "GZIP";
} } // namespace compression
} } // namespace io
} } // namespace tensorflow


@ -23,8 +23,8 @@ namespace compression {
extern const char kNone[]; extern const char kNone[];
extern const char kGzip[]; extern const char kGzip[];
} } // namespace compression
} } // namespace io
} } // namespace tensorflow
#endif // TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_ #endif // TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_


@ -207,7 +207,7 @@ Status RecordReader::SkipNBytes(uint64 offset) {
} }
} }
return Status::OK(); return Status::OK();
} } // namespace io
SequentialRecordReader::SequentialRecordReader( SequentialRecordReader::SequentialRecordReader(
RandomAccessFile* file, const RecordReaderOptions& options) RandomAccessFile* file, const RecordReaderOptions& options)


@ -218,8 +218,8 @@ TEST_F(RecordioTest, RandomRead) {
// Tests of all the error paths in log_reader.cc follow: // Tests of all the error paths in log_reader.cc follow:
static void AssertHasSubstr(StringPiece s, StringPiece expected) { static void AssertHasSubstr(StringPiece s, StringPiece expected) {
EXPECT_TRUE(StringPiece(s).contains(expected)) << s << " does not contain " EXPECT_TRUE(StringPiece(s).contains(expected))
<< expected; << s << " does not contain " << expected;
} }
TEST_F(RecordioTest, ReadError) { TEST_F(RecordioTest, ReadError) {


@ -197,8 +197,8 @@ bool CommonInitDecode(StringPiece png_string, int desired_channels,
int desired_channel_bits, DecodeContext* context) { int desired_channel_bits, DecodeContext* context) {
CHECK(desired_channel_bits == 8 || desired_channel_bits == 16) CHECK(desired_channel_bits == 8 || desired_channel_bits == 16)
<< "desired_channel_bits = " << desired_channel_bits; << "desired_channel_bits = " << desired_channel_bits;
CHECK(0 <= desired_channels && desired_channels <= 4) << "desired_channels = " CHECK(0 <= desired_channels && desired_channels <= 4)
<< desired_channels; << "desired_channels = " << desired_channels;
context->error_condition = false; context->error_condition = false;
context->channels = desired_channels; context->channels = desired_channels;
context->png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, context, context->png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, context,


@ -35,8 +35,8 @@ void FillRandoms(PhiloxRandom gen, typename Distribution::ResultElementType* p,
int64 size) { int64 size) {
const int granularity = Distribution::kResultElementCount; const int granularity = Distribution::kResultElementCount;
CHECK(size % granularity == 0) << " size: " << size CHECK(size % granularity == 0)
<< " granularity: " << granularity; << " size: " << size << " granularity: " << granularity;
Distribution dist; Distribution dist;
for (int i = 0; i < size; i += granularity) { for (int i = 0; i < size; i += granularity) {


@ -17,8 +17,8 @@ limitations under the License.
#define TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ #define TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
#define _USE_MATH_DEFINES #define _USE_MATH_DEFINES
#include <cmath>
#include <math.h> #include <math.h>
#include <cmath>
#undef _USE_MATH_DEFINES #undef _USE_MATH_DEFINES
#include <string.h> #include <string.h>
@ -27,7 +27,6 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/philox_random.h"
namespace tensorflow { namespace tensorflow {
namespace random { namespace random {


@ -45,8 +45,8 @@ void FillRandomsWithSingles(PhiloxRandom gen,
int64 size) { int64 size) {
int granularity = Distribution::kResultElementCount; int granularity = Distribution::kResultElementCount;
CHECK(size % granularity == 0) << " size: " << size CHECK(size % granularity == 0)
<< " granularity: " << granularity; << " size: " << size << " granularity: " << granularity;
SingleSampleAdapter<PhiloxRandom> single_samples(&gen); SingleSampleAdapter<PhiloxRandom> single_samples(&gen);


@ -472,7 +472,8 @@ void OrderedCode::WriteSignedNumIncreasing(string* dest, int64 val) {
// buf = val in network byte order, sign extended to 10 bytes // buf = val in network byte order, sign extended to 10 bytes
const char sign_byte = val < 0 ? '\xff' : '\0'; const char sign_byte = val < 0 ? '\xff' : '\0';
char buf[10] = { char buf[10] = {
sign_byte, sign_byte, sign_byte,
sign_byte,
}; };
StoreBigEndian64(buf + 2, val); StoreBigEndian64(buf + 2, val);
static_assert(sizeof(buf) == kMaxSigned64Length, "max length size mismatch"); static_assert(sizeof(buf) == kMaxSigned64Length, "max length size mismatch");
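
One subtlety in the buf initializer above: listing only two sign bytes is enough, because an aggregate initializer with fewer elements than the array value-initializes the remainder. So buf[2] through buf[9] start as '\0' before StoreBigEndian64 overwrites them. In isolation:

    char buf[10] = {'\xff', '\xff'};
    // buf[0] == buf[1] == '\xff'; buf[2] through buf[9] are '\0'.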


@ -25,8 +25,9 @@ namespace tensorflow {
namespace { namespace {
TEST(BackwardsCompatibilityTest, IsCompatible) { TEST(BackwardsCompatibilityTest, IsCompatible) {
OpCompatibilityLib compatibility( OpCompatibilityLib compatibility("tensorflow/core/ops",
"tensorflow/core/ops", strings::StrCat("v", TF_MAJOR_VERSION), nullptr); strings::StrCat("v", TF_MAJOR_VERSION),
nullptr);
Env* env = Env::Default(); Env* env = Env::Default();
int changed_ops = 0; int changed_ops = 0;


@ -18,9 +18,9 @@ limitations under the License.
#include <arpa/inet.h> #include <arpa/inet.h>
#include <netdb.h> #include <netdb.h>
#else #else
#include <Windows.h>
#include <winsock2.h> #include <winsock2.h>
#include <ws2tcpip.h> #include <ws2tcpip.h>
#include <Windows.h>
#endif #endif
#include <sys/types.h> #include <sys/types.h>


@ -38,8 +38,7 @@ class FakeHttpRequest : public CurlHttpRequest {
public: public:
/// Return the response for the given request. /// Return the response for the given request.
FakeHttpRequest(const string& request, const string& response) FakeHttpRequest(const string& request, const string& response)
: FakeHttpRequest(request, response, Status::OK(), nullptr, {}, 200) { : FakeHttpRequest(request, response, Status::OK(), nullptr, {}, 200) {}
}
/// Return the response with headers for the given request. /// Return the response with headers for the given request.
FakeHttpRequest(const string& request, const string& response, FakeHttpRequest(const string& request, const string& response,


@ -160,10 +160,10 @@ TEST(OAuthClientTest, GetTokenFromServiceAccountJson) {
ASSERT_EQ(1, EVP_DigestVerifyInit(md_ctx, nullptr, md, nullptr, key)); ASSERT_EQ(1, EVP_DigestVerifyInit(md_ctx, nullptr, md, nullptr, key));
ASSERT_EQ(1, EVP_DigestVerifyUpdate(md_ctx, header_dot_claim.c_str(), ASSERT_EQ(1, EVP_DigestVerifyUpdate(md_ctx, header_dot_claim.c_str(),
header_dot_claim.size())); header_dot_claim.size()));
ASSERT_EQ( ASSERT_EQ(1,
1,
EVP_DigestVerifyFinal( EVP_DigestVerifyFinal(
md_ctx, const_cast<unsigned char*>( md_ctx,
const_cast<unsigned char*>(
reinterpret_cast<const unsigned char*>(signature.data())), reinterpret_cast<const unsigned char*>(signature.data())),
signature.size())); signature.size()));
EVP_MD_CTX_cleanup(md_ctx); EVP_MD_CTX_cleanup(md_ctx);


@ -25,7 +25,6 @@ namespace tensorflow {
namespace { namespace {
class RetryingRandomAccessFile : public RandomAccessFile { class RetryingRandomAccessFile : public RandomAccessFile {
public: public:
RetryingRandomAccessFile(std::unique_ptr<RandomAccessFile> base_file, RetryingRandomAccessFile(std::unique_ptr<RandomAccessFile> base_file,


@ -27,8 +27,7 @@ TEST(CudaLibdevicePathTest, LibdevicePath) {
VLOG(2) << "Libdevice root = " << LibdeviceRoot(); VLOG(2) << "Libdevice root = " << LibdeviceRoot();
std::vector<string> libdevice_files; std::vector<string> libdevice_files;
TF_EXPECT_OK(Env::Default()->GetMatchingPaths( TF_EXPECT_OK(Env::Default()->GetMatchingPaths(
io::JoinPath(LibdeviceRoot(), "libdevice.*.bc"), io::JoinPath(LibdeviceRoot(), "libdevice.*.bc"), &libdevice_files));
&libdevice_files));
EXPECT_LT(0, libdevice_files.size()); EXPECT_LT(0, libdevice_files.size());
} }
#endif #endif


@ -579,8 +579,10 @@ Status DeviceTracerImpl::Collect(StepStatsCollector *collector) {
// TODO(pbar) Handle device IDs and prefix properly. // TODO(pbar) Handle device IDs and prefix properly.
const string prefix = ""; const string prefix = "";
const int id = 0; const int id = 0;
const string stream_device = strings::StrCat(prefix, "/device:GPU:", id, "/stream:"); const string stream_device =
const string memcpy_device = strings::StrCat(prefix, "/device:GPU:", id, "/memcpy"); strings::StrCat(prefix, "/device:GPU:", id, "/stream:");
const string memcpy_device =
strings::StrCat(prefix, "/device:GPU:", id, "/memcpy");
mutex_lock l2(trace_mu_); mutex_lock l2(trace_mu_);
for (const auto &rec : kernel_records_) { for (const auto &rec : kernel_records_) {


@ -91,7 +91,6 @@ void LogMessage::GenerateLogMessage() {
} }
#endif #endif
namespace { namespace {
// Parse log level (int64) from environment variable (char*) // Parse log level (int64) from environment variable (char*)


@ -19,8 +19,8 @@ limitations under the License.
// IWYU pragma: private, include "third_party/tensorflow/core/platform/logging.h" // IWYU pragma: private, include "third_party/tensorflow/core/platform/logging.h"
// IWYU pragma: friend third_party/tensorflow/core/platform/logging.h // IWYU pragma: friend third_party/tensorflow/core/platform/logging.h
#include <sstream>
#include <limits> #include <limits>
#include <sstream>
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
@ -205,16 +205,18 @@ string* MakeCheckOpString(const T1& v1, const T2& v2, const char* exprtext) {
inline string* name##Impl(int v1, int v2, const char* exprtext) { \ inline string* name##Impl(int v1, int v2, const char* exprtext) { \
return name##Impl<int, int>(v1, v2, exprtext); \ return name##Impl<int, int>(v1, v2, exprtext); \
} \ } \
inline string* name##Impl(const size_t v1, const int v2, const char* exprtext) { \ inline string* name##Impl(const size_t v1, const int v2, \
const char* exprtext) { \
if (TF_PREDICT_FALSE(v2 < 0)) { \ if (TF_PREDICT_FALSE(v2 < 0)) { \
return ::tensorflow::internal::MakeCheckOpString(v1, v2, exprtext);\ return ::tensorflow::internal::MakeCheckOpString(v1, v2, exprtext); \
} \ } \
const size_t uval = (size_t)((unsigned)v1); \ const size_t uval = (size_t)((unsigned)v1); \
return name##Impl<size_t, size_t>(uval, v2, exprtext); \ return name##Impl<size_t, size_t>(uval, v2, exprtext); \
} \ } \
inline string* name##Impl(const int v1, const size_t v2, const char* exprtext) { \ inline string* name##Impl(const int v1, const size_t v2, \
const char* exprtext) { \
if (TF_PREDICT_FALSE(v2 >= std::numeric_limits<int>::max())) { \ if (TF_PREDICT_FALSE(v2 >= std::numeric_limits<int>::max())) { \
return ::tensorflow::internal::MakeCheckOpString(v1, v2, exprtext);\ return ::tensorflow::internal::MakeCheckOpString(v1, v2, exprtext); \
} \ } \
const size_t uval = (size_t)((unsigned)v2); \ const size_t uval = (size_t)((unsigned)v2); \
return name##Impl<size_t, size_t>(v1, uval, exprtext); \ return name##Impl<size_t, size_t>(v1, uval, exprtext); \
@ -225,12 +227,12 @@ string* MakeCheckOpString(const T1& v1, const T2& v2, const char* exprtext) {
// This happens if, for example, those are used as token names in a // This happens if, for example, those are used as token names in a
// yacc grammar. // yacc grammar.
TF_DEFINE_CHECK_OP_IMPL(Check_EQ, TF_DEFINE_CHECK_OP_IMPL(Check_EQ,
== ) // Compilation error with CHECK_EQ(NULL, x)? ==) // Compilation error with CHECK_EQ(NULL, x)?
TF_DEFINE_CHECK_OP_IMPL(Check_NE, != ) // Use CHECK(x == NULL) instead. TF_DEFINE_CHECK_OP_IMPL(Check_NE, !=) // Use CHECK(x == NULL) instead.
TF_DEFINE_CHECK_OP_IMPL(Check_LE, <= ) TF_DEFINE_CHECK_OP_IMPL(Check_LE, <=)
TF_DEFINE_CHECK_OP_IMPL(Check_LT, < ) TF_DEFINE_CHECK_OP_IMPL(Check_LT, <)
TF_DEFINE_CHECK_OP_IMPL(Check_GE, >= ) TF_DEFINE_CHECK_OP_IMPL(Check_GE, >=)
TF_DEFINE_CHECK_OP_IMPL(Check_GT, > ) TF_DEFINE_CHECK_OP_IMPL(Check_GT, >)
#undef TF_DEFINE_CHECK_OP_IMPL #undef TF_DEFINE_CHECK_OP_IMPL
// In optimized mode, use CheckOpString to hint to compiler that // In optimized mode, use CheckOpString to hint to compiler that
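
These TF_DEFINE_CHECK_OP_IMPL lines generate the comparison helpers behind TensorFlow's CHECK_EQ / CHECK_NE / CHECK_LE / CHECK_LT / CHECK_GE / CHECK_GT macros; once more the change merely removes a space before the closing parenthesis. Illustrative call sites (the names are made up):

    CHECK_EQ(shape.dims(), 2) << "expected a matrix, got " << shape.DebugString();
    CHECK_GT(batch_size, 0);
    // As the inline comments warn, CHECK_EQ(NULL, x) does not compile;
    // write CHECK(x == NULL) instead.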


@ -41,8 +41,8 @@ namespace tensorflow {
namespace port { namespace port {
ScopedFlushDenormal::ScopedFlushDenormal() { ScopedFlushDenormal::ScopedFlushDenormal() {
// For now, we flush denormals only on SSE 3. Other architectures such as ARM // For now, we flush denormals only on SSE 3. Other architectures such as ARM
// can be added as needed. // can be added as needed.
#ifdef DENORM_USE_INTRINSICS #ifdef DENORM_USE_INTRINSICS
if (TestCPUFeature(SSE3)) { if (TestCPUFeature(SSE3)) {


@ -77,7 +77,8 @@ class DeviceTracerTest : public ::testing::Test {
Node* y_neg = test::graph::Unary(&graph, "Neg", i); Node* y_neg = test::graph::Unary(&graph, "Neg", i);
y_neg_ = y_neg->name(); y_neg_ = y_neg->name();
y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0"); y_neg->set_assigned_device_name(
"/job:localhost/replica:0/task:0/device:GPU:0");
test::graph::ToGraphDef(&graph, &def_); test::graph::ToGraphDef(&graph, &def_);
} }


@ -353,6 +353,7 @@ class EnvWrapper : public Env {
const string& version) override { const string& version) override {
return target_->FormatLibraryFileName(name, version); return target_->FormatLibraryFileName(name, version);
} }
private: private:
Env* target_; Env* target_;
}; };


@ -131,14 +131,15 @@ Status FileSystem::GetMatchingPaths(const string& pattern,
if (children.empty()) continue; if (children.empty()) continue;
// This IsDirectory call can be expensive for some FS. Parallelizing it. // This IsDirectory call can be expensive for some FS. Parallelizing it.
children_dir_status.resize(children.size()); children_dir_status.resize(children.size());
ForEach(0, children.size(), [this, &current_dir, &children, &fixed_prefix, ForEach(0, children.size(),
[this, &current_dir, &children, &fixed_prefix,
&children_dir_status](int i) { &children_dir_status](int i) {
const string child_path = io::JoinPath(current_dir, children[i]); const string child_path = io::JoinPath(current_dir, children[i]);
// In case the child_path doesn't start with the fixed_prefix then // In case the child_path doesn't start with the fixed_prefix then
// we don't need to explore this path. // we don't need to explore this path.
if (!StringPiece(child_path).starts_with(fixed_prefix)) { if (!StringPiece(child_path).starts_with(fixed_prefix)) {
children_dir_status[i] = children_dir_status[i] = Status(tensorflow::error::CANCELLED,
Status(tensorflow::error::CANCELLED, "Operation not needed"); "Operation not needed");
} else { } else {
children_dir_status[i] = IsDirectory(child_path); children_dir_status[i] = IsDirectory(child_path);
} }


@ -20,7 +20,8 @@ limitations under the License.
#if defined(PLATFORM_GOOGLE) #if defined(PLATFORM_GOOGLE)
#include "tensorflow/core/platform/google/build_config/gif.h" #include "tensorflow/core/platform/google/build_config/gif.h"
#elif defined(PLATFORM_POSIX)|| defined(PLATFORM_WINDOWS) ||defined(PLATFORM_POSIX_ANDROID) #elif defined(PLATFORM_POSIX) || defined(PLATFORM_WINDOWS) || \
defined(PLATFORM_POSIX_ANDROID)
#include <gif_lib.h> #include <gif_lib.h>
#else #else
#error Define the appropriate PLATFORM_<foo> macro for this platform #error Define the appropriate PLATFORM_<foo> macro for this platform


@ -164,8 +164,9 @@ Status HadoopFileSystem::Connect(StringPiece fname, hdfsFS* fs) {
} else { } else {
hdfs_->hdfsBuilderSetNameNode(builder, nn.c_str()); hdfs_->hdfsBuilderSetNameNode(builder, nn.c_str());
} }
// KERB_TICKET_CACHE_PATH will be deleted in the future, Because KRB5CCNAME is the build in // KERB_TICKET_CACHE_PATH will be deleted in the future, Because KRB5CCNAME is
// environment variable of Kerberos, so KERB_TICKET_CACHE_PATH and related code are unnecessary. // the build in environment variable of Kerberos, so KERB_TICKET_CACHE_PATH
// and related code are unnecessary.
char* ticket_cache_path = getenv("KERB_TICKET_CACHE_PATH"); char* ticket_cache_path = getenv("KERB_TICKET_CACHE_PATH");
if (ticket_cache_path != nullptr) { if (ticket_cache_path != nullptr) {
hdfs_->hdfsBuilderSetKerbTicketCachePath(builder, ticket_cache_path); hdfs_->hdfsBuilderSetKerbTicketCachePath(builder, ticket_cache_path);


@ -20,7 +20,8 @@ limitations under the License.
#if defined(PLATFORM_GOOGLE) #if defined(PLATFORM_GOOGLE)
#include "tensorflow/core/platform/google/build_config/jpeg.h" #include "tensorflow/core/platform/google/build_config/jpeg.h"
#elif defined(PLATFORM_POSIX)|| defined(PLATFORM_WINDOWS) ||defined(PLATFORM_POSIX_ANDROID) #elif defined(PLATFORM_POSIX) || defined(PLATFORM_WINDOWS) || \
defined(PLATFORM_POSIX_ANDROID)
#include <stddef.h> #include <stddef.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>


@ -20,7 +20,8 @@ limitations under the License.
#if defined(PLATFORM_GOOGLE) #if defined(PLATFORM_GOOGLE)
#include "tensorflow/core/platform/google/build_config/png.h" #include "tensorflow/core/platform/google/build_config/png.h"
#elif defined(PLATFORM_POSIX)|| defined(PLATFORM_WINDOWS) ||defined(PLATFORM_POSIX_ANDROID) #elif defined(PLATFORM_POSIX) || defined(PLATFORM_WINDOWS) || \
defined(PLATFORM_POSIX_ANDROID)
#include <png.h> #include <png.h>
#else #else
#error Define the appropriate PLATFORM_<foo> macro for this platform #error Define the appropriate PLATFORM_<foo> macro for this platform


@ -58,8 +58,8 @@ class AndroidArmV7ACpuUtilsHelper : public ICpuUtilsHelper {
TF_DISALLOW_COPY_AND_ASSIGN(AndroidArmV7ACpuUtilsHelper); TF_DISALLOW_COPY_AND_ASSIGN(AndroidArmV7ACpuUtilsHelper);
}; };
} // profile_utils } // namespace profile_utils
} // tensorflow } // namespace tensorflow
#endif // defined(__ANDROID__) && (__ANDROID_API__ >= 21) && #endif // defined(__ANDROID__) && (__ANDROID_API__ >= 21) &&
// (defined(__ARM_ARCH_7A__) || defined(__aarch64__)) // (defined(__ARM_ARCH_7A__) || defined(__aarch64__))


@ -28,13 +28,15 @@ namespace profile_utils {
static ICpuUtilsHelper* cpu_utils_helper_instance_ = nullptr; static ICpuUtilsHelper* cpu_utils_helper_instance_ = nullptr;
#if (defined(__powerpc__) || defined(__ppc__) && ( __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) || (defined(__s390x__)) #if (defined(__powerpc__) || \
/* static */ uint64 CpuUtils::GetCycleCounterFrequency() { defined(__ppc__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) || \
(defined(__s390x__))
/* static */ uint64 CpuUtils::GetCycleCounterFrequency() {
static const uint64 cpu_frequency = GetCycleCounterFrequencyImpl(); static const uint64 cpu_frequency = GetCycleCounterFrequencyImpl();
return cpu_frequency; return cpu_frequency;
} }
#else #else
/* static */ int64 CpuUtils::GetCycleCounterFrequency() { /* static */ int64 CpuUtils::GetCycleCounterFrequency() {
static const int64 cpu_frequency = GetCycleCounterFrequencyImpl(); static const int64 cpu_frequency = GetCycleCounterFrequencyImpl();
return cpu_frequency; return cpu_frequency;
} }


@ -94,14 +94,16 @@ class CpuUtils {
#endif #endif
} }
// Return cycle counter frequency. // Return cycle counter frequency.
// As this method caches the cpu frequency internally, // As this method caches the cpu frequency internally,
// the first call will incur overhead, but not subsequent calls. // the first call will incur overhead, but not subsequent calls.
#if (defined(__powerpc__) || defined(__ppc__) && ( __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) || (defined(__s390x__)) #if (defined(__powerpc__) || \
defined(__ppc__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) || \
(defined(__s390x__))
static uint64 GetCycleCounterFrequency(); static uint64 GetCycleCounterFrequency();
#else #else
static int64 GetCycleCounterFrequency(); static int64 GetCycleCounterFrequency();
#endif #endif
// Return micro secound per each clock // Return micro secound per each clock
// As this method caches the cpu frequency internally, // As this method caches the cpu frequency internally,


@ -53,15 +53,17 @@ TEST_F(CpuUtilsTest, CheckGetCurrentClockCycle) {
} }
TEST_F(CpuUtilsTest, CheckCycleCounterFrequency) { TEST_F(CpuUtilsTest, CheckCycleCounterFrequency) {
#if (defined(__powerpc__) || defined(__ppc__) && ( __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) || (defined(__s390x__)) #if (defined(__powerpc__) || \
defined(__ppc__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) || \
(defined(__s390x__))
const uint64 cpu_frequency = CpuUtils::GetCycleCounterFrequency(); const uint64 cpu_frequency = CpuUtils::GetCycleCounterFrequency();
CHECK_GT(cpu_frequency, 0); CHECK_GT(cpu_frequency, 0);
CHECK_NE(cpu_frequency, unsigned(CpuUtils::INVALID_FREQUENCY)); CHECK_NE(cpu_frequency, unsigned(CpuUtils::INVALID_FREQUENCY));
#else #else
const int64 cpu_frequency = CpuUtils::GetCycleCounterFrequency(); const int64 cpu_frequency = CpuUtils::GetCycleCounterFrequency();
CHECK_GT(cpu_frequency, 0); CHECK_GT(cpu_frequency, 0);
CHECK_NE(cpu_frequency, CpuUtils::INVALID_FREQUENCY); CHECK_NE(cpu_frequency, CpuUtils::INVALID_FREQUENCY);
#endif #endif
if (DBG) { if (DBG) {
LOG(INFO) << "Cpu frequency = " << cpu_frequency; LOG(INFO) << "Cpu frequency = " << cpu_frequency;
} }
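
A note on the reflowed #if conditions in this file and the two before it: clang-format only re-wrapped the lines; the grouping is unchanged because && binds tighter than || in preprocessor expressions, exactly as in C++. Fully parenthesized, the condition reads:

    #if (defined(__powerpc__) ||                                                \
         (defined(__ppc__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__))) ||  \
        (defined(__s390x__))
    // i.e. __powerpc__ (any endianness), or little-endian __ppc__, or __s390x__.
    #endif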


@ -47,7 +47,7 @@ class ICpuUtilsHelper {
TF_DISALLOW_COPY_AND_ASSIGN(ICpuUtilsHelper); TF_DISALLOW_COPY_AND_ASSIGN(ICpuUtilsHelper);
}; };
} // profile_utils } // namespace profile_utils
} // tensorflow } // namespace tensorflow
#endif // TENSORFLOW_PLATFORM_PROFILEUTILS_I_CPU_UTILS_HELPER_H__ #endif // TENSORFLOW_PLATFORM_PROFILEUTILS_I_CPU_UTILS_HELPER_H__

Some files were not shown because too many files have changed in this diff.