From c6d1d34cf28662542df45584422ce2e3311e0d13 Mon Sep 17 00:00:00 2001 From: Doe Hyun Yoon Date: Mon, 7 Dec 2020 10:17:24 -0800 Subject: [PATCH] Refactor op level cost estimator -- cost functions report raw data (e.g., num ops, num bytes, not time), and then PredictCosts() translates it to time (in Costs). (1) Add a new structure, NodeCosts -- this is supposed to be used within op_level_cost_estimator; not for the users of OpLevelCostEstimator. (2) PredictCosts calls PredictNodeCosts, and then convert NodeCosts to Costs; users of OpLevelCostEstimator wouldn't see any difference. (3) The signature of Predict methods for each op type is Status Predict***(OpContext&, NodeCost*); within OpLevelCostEstimator, we'll use Status for handling erroneous cases. (4) Fixed PredictSoftmax(): previously, it incorrectly checking input is rank-2, but it can be any rank >=1. (5) Predict times for fused ops are changed (in unit test, 2ns at most); that's because we now add bytes (in int64) and then calculate time, whereas previously, we first calculate time for each op, and then add them, but bytes to time may introduces some errors (int to float), the current approach is more accurate (however small delta it is). (6) CropAndResize op cost ignored 2nd, 3rd, and 4th input tensors' mem cost; it's now incorporated. PiperOrigin-RevId: 346121141 Change-Id: I6caf1123f99dac6897f048644222f2fb46417885 --- tensorflow/core/grappler/costs/BUILD | 1 + .../grappler/costs/op_level_cost_estimator.cc | 782 ++++++++++-------- .../grappler/costs/op_level_cost_estimator.h | 175 +++- .../costs/op_level_cost_estimator_test.cc | 41 +- 4 files changed, 624 insertions(+), 375 deletions(-) diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 8bda3f60913..951a78d9bac 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -334,6 +334,7 @@ cc_library( "@com_google_absl//absl/strings", "//third_party/eigen3", "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler/clusters:utils", ] + tf_protos_grappler(), diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 086f1e99b97..0b28dc3b18c 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -26,10 +26,12 @@ limitations under the License. #include "tensorflow/core/grappler/clusters/utils.h" #include "tensorflow/core/grappler/costs/op_context.h" #include "tensorflow/core/grappler/costs/utils.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace grappler { +// TODO(dyoon): update op to Predict method map for TF ops with V2 or V3 suffix. constexpr int kOpsPerMac = 2; constexpr char kGuaranteeConst[] = "GuaranteeConst"; constexpr char kAddN[] = "AddN"; @@ -121,6 +123,7 @@ constexpr char kAssignAddVariableOp[] = "AssignAddVariableOp"; constexpr char kAssignSubVariableOp[] = "AssignSubVariableOp"; static const Costs::Duration kMinComputeTime(1); +static const int64 kMinComputeOp = 1; namespace { @@ -354,11 +357,12 @@ TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape, OpLevelCostEstimator::OpLevelCostEstimator() { // Syntactic sugar to build and return a lambda that takes an OpInfo and // returns a cost. - typedef Costs (OpLevelCostEstimator::*CostImpl)(const OpContext& op_context) - const; - auto wrap = [this](CostImpl impl) -> std::function { - return [this, impl](const OpContext& op_context) { - return (this->*impl)(op_context); + typedef Status (OpLevelCostEstimator::*CostImpl)(const OpContext& op_context, + NodeCosts*) const; + auto wrap = [this](CostImpl impl) + -> std::function { + return [this, impl](const OpContext& op_context, NodeCosts* node_costs) { + return (this->*impl)(op_context, node_costs); }; }; @@ -642,27 +646,72 @@ OpLevelCostEstimator::OpLevelCostEstimator() { } Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const { + Costs costs; + NodeCosts node_costs; + if (PredictNodeCosts(op_context, &node_costs).ok()) { + if (node_costs.has_costs) { + return node_costs.costs; + } + // Convert NodeCosts to Costs. + if (node_costs.minimum_cost_op) { + // Override to minimum cost; Note that some ops with minimum cost may have + // non-typical device (e.g., channel for _Send), which may fail with + // GetDeviceInfo(), called from PredictOpCountBasedCost(). Make sure we + // directly set minimum values to Costs here, not calling + // PredictOpCountBasedCost(). + costs.compute_time = kMinComputeTime; + costs.execution_time = kMinComputeTime; + costs.memory_time = 0; + costs.intermediate_memory_time = 0; + costs.intermediate_memory_read_time = 0; + costs.intermediate_memory_write_time = 0; + } else { + // Convert NodeCosts to Costs. + costs = PredictOpCountBasedCost( + node_costs.num_compute_ops, node_costs.num_total_read_bytes(), + node_costs.num_total_write_bytes(), op_context.op_info); + } + VLOG(1) << "Operation " << op_context.op_info.op() << " takes " + << costs.execution_time.count() << " ns."; + // Copy additional stats from NodeCosts to Costs. + costs.max_memory = node_costs.max_memory; + costs.persistent_memory = node_costs.persistent_memory; + costs.temporary_memory = node_costs.temporary_memory; + costs.inaccurate = node_costs.inaccurate; + costs.num_ops_with_unknown_shapes = + node_costs.num_nodes_with_unknown_shapes; + costs.num_ops_total = node_costs.num_nodes; + return costs; + } + // Errors during node cost estimate. + LOG(WARNING) << "Error in PredictCost() for the op: " + << op_context.op_info.ShortDebugString(); + costs = Costs::ZeroCosts(/*inaccurate=*/true); + costs.num_ops_with_unknown_shapes = node_costs.num_nodes_with_unknown_shapes; + return costs; +} + +Status OpLevelCostEstimator::PredictNodeCosts(const OpContext& op_context, + NodeCosts* node_costs) const { const auto& op_info = op_context.op_info; auto it = device_cost_impl_.find(op_info.op()); if (it != device_cost_impl_.end()) { - std::function estimator = it->second; - Costs costs = estimator(op_context); - VLOG(1) << "Operation " << op_info.op() << " takes " - << costs.execution_time.count() << " ns."; - return costs; + std::function estimator = it->second; + return estimator(op_context, node_costs); } if (persistent_ops_.find(op_info.op()) != persistent_ops_.end()) { - return PredictVariable(op_context); + return PredictVariable(op_context, node_costs); } if (elementwise_ops_.find(op_info.op()) != elementwise_ops_.end()) { - return PredictCwiseOp(op_context); + return PredictCwiseOp(op_context, node_costs); } VLOG(1) << "Missing accurate estimator for op: " << op_info.op(); - return PredictCostOfAnUnknownOp(op_context); + node_costs->num_nodes_with_unknown_op_type = 1; + return PredictCostOfAnUnknownOp(op_context, node_costs); } DeviceInfo OpLevelCostEstimator::GetDeviceInfo( @@ -716,7 +765,8 @@ DeviceInfo OpLevelCostEstimator::GetDeviceInfo( return DeviceInfo(gflops, gb_per_sec); } -Costs OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context, + NodeCosts* node_costs) const { const auto& op_info = op_context.op_info; bool found_unknown_shapes = false; // For element-wise operations, op count is the element count of any input. We @@ -736,30 +786,25 @@ Costs OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context) const { } int op_cost = 1; - bool is_known_elementwise_op = false; auto it = elementwise_ops_.find(op_info.op()); if (it != elementwise_ops_.end()) { op_cost = it->second; - is_known_elementwise_op = true; } else { - LOG(WARNING) << "Not a cwise op: " << op_info.op(); + return errors::InvalidArgument("Not a cwise op: ", op_info.op()); } - Costs costs = PredictOpCountBasedCost(op_count * op_cost, op_info); - if (found_unknown_shapes || !is_known_elementwise_op) { - costs.inaccurate = true; - } - costs.num_ops_with_unknown_shapes = found_unknown_shapes; - return costs; + return PredictDefaultNodeCosts(op_count * op_cost, op_context, + &found_unknown_shapes, node_costs); } -Costs OpLevelCostEstimator::PredictCostOfAnUnknownOp( - const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictCostOfAnUnknownOp( + const OpContext& op_context, NodeCosts* node_costs) const { // Don't assume the operation is cwise, return cost based on input/output size // and admit that it is inaccurate... - auto costs = PredictOpCountBasedCost(0, op_context.op_info); - costs.inaccurate = true; - return costs; + bool found_unknown_shapes = false; + node_costs->inaccurate = true; + return PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes, + node_costs); } Costs OpLevelCostEstimator::PredictOpCountBasedCost( @@ -1509,6 +1554,17 @@ int64 OpLevelCostEstimator::CalculateInputSize(const OpInfo& op_info, return total_input_size; } +std::vector OpLevelCostEstimator::CalculateInputTensorSize( + const OpInfo& op_info, bool* found_unknown_shapes) { + std::vector input_tensor_size; + input_tensor_size.reserve(op_info.inputs().size()); + for (auto& input : op_info.inputs()) { + input_tensor_size.push_back( + CalculateTensorSize(input, found_unknown_shapes)); + } + return input_tensor_size; +} + int64 OpLevelCostEstimator::CalculateLargestInputCount( const OpInfo& op_info, bool* found_unknown_shapes) { int64 largest_input_count = 0; @@ -1527,7 +1583,7 @@ int64 OpLevelCostEstimator::CalculateLargestInputCount( int64 OpLevelCostEstimator::CalculateOutputSize(const OpInfo& op_info, bool* found_unknown_shapes) { int64 total_output_size = 0; - // use float as default for calculations + // Use float as default for calculations. for (const auto& output : op_info.outputs()) { DataType dt = output.dtype(); const auto& original_output_shape = output.shape(); @@ -1545,6 +1601,43 @@ int64 OpLevelCostEstimator::CalculateOutputSize(const OpInfo& op_info, return total_output_size; } +std::vector OpLevelCostEstimator::CalculateOutputTensorSize( + const OpInfo& op_info, bool* found_unknown_shapes) { + std::vector output_tensor_size; + output_tensor_size.reserve(op_info.outputs().size()); + // Use float as default for calculations. + for (const auto& output : op_info.outputs()) { + DataType dt = output.dtype(); + const auto& original_output_shape = output.shape(); + int64 output_size = DataTypeSize(BaseType(dt)); + int num_dims = std::max(1, original_output_shape.dim_size()); + auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims, + found_unknown_shapes); + for (const auto& dim : output_shape.dim()) { + output_size *= dim.size(); + } + output_tensor_size.push_back(output_size); + } + return output_tensor_size; +} + +Status OpLevelCostEstimator::PredictDefaultNodeCosts( + const int64 num_compute_ops, const OpContext& op_context, + bool* found_unknown_shapes, NodeCosts* node_costs) { + const auto& op_info = op_context.op_info; + node_costs->num_compute_ops = num_compute_ops; + node_costs->num_input_bytes_accessed = + CalculateInputTensorSize(op_info, found_unknown_shapes); + node_costs->num_output_bytes_accessed = + CalculateOutputTensorSize(op_info, found_unknown_shapes); + node_costs->max_memory = node_costs->num_total_output_bytes(); + if (*found_unknown_shapes) { + node_costs->inaccurate = true; + node_costs->num_nodes_with_unknown_shapes = 1; + } + return Status::OK(); +} + bool HasZeroDim(const OpInfo& op_info) { for (int i = 0; i < op_info.inputs_size(); ++i) { const auto& input = op_info.inputs(i); @@ -1560,62 +1653,54 @@ bool HasZeroDim(const OpInfo& op_info) { return false; } -Costs OpLevelCostEstimator::PredictConv2D(const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictConv2D(const OpContext& op_context, + NodeCosts* node_costs) const { const auto& op_info = op_context.op_info; if (HasZeroDim(op_info)) { - Costs costs = Costs::ZeroCosts(); - costs.inaccurate = true; - costs.num_ops_with_unknown_shapes = 1; - return costs; + node_costs->num_nodes_with_unknown_shapes = 1; + return errors::InvalidArgument("Conv2D op includes zero dimension: ", + op_info.ShortDebugString()); } bool found_unknown_shapes = false; - auto costs = PredictOpCountBasedCost( - CountConv2DOperations(op_info, &found_unknown_shapes), op_info); - costs.inaccurate = found_unknown_shapes; - costs.num_ops_with_unknown_shapes = found_unknown_shapes; - return costs; + int64 num_compute_ops = CountConv2DOperations(op_info, &found_unknown_shapes); + return PredictDefaultNodeCosts(num_compute_ops, op_context, + &found_unknown_shapes, node_costs); } -Costs OpLevelCostEstimator::PredictConv2DBackpropInput( - const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictConv2DBackpropInput( + const OpContext& op_context, NodeCosts* node_costs) const { const auto& op_info = op_context.op_info; if (HasZeroDim(op_info)) { - Costs costs = Costs::ZeroCosts(); - costs.inaccurate = true; - costs.num_ops_with_unknown_shapes = true; - return costs; + node_costs->num_nodes_with_unknown_shapes = 1; + return errors::InvalidArgument( + "Conv2DBackpropInput op includes zero dimension", + op_info.ShortDebugString()); } bool found_unknown_shapes = false; - auto costs = - PredictOpCountBasedCost(CountConv2DBackpropInputOperations( - op_info, nullptr, &found_unknown_shapes), - op_info); - costs.inaccurate = found_unknown_shapes; - costs.num_ops_with_unknown_shapes = found_unknown_shapes; - return costs; + int64 num_compute_ops = CountConv2DBackpropInputOperations( + op_info, nullptr, &found_unknown_shapes); + return PredictDefaultNodeCosts(num_compute_ops, op_context, + &found_unknown_shapes, node_costs); } -Costs OpLevelCostEstimator::PredictConv2DBackpropFilter( - const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictConv2DBackpropFilter( + const OpContext& op_context, NodeCosts* node_costs) const { const auto& op_info = op_context.op_info; if (HasZeroDim(op_info)) { - Costs costs = Costs::ZeroCosts(); - costs.inaccurate = true; - costs.num_ops_with_unknown_shapes = true; - return costs; + node_costs->num_nodes_with_unknown_shapes = 1; + return errors::InvalidArgument( + "Conv2DBackpropFilter op includes zero dimension", + op_info.ShortDebugString()); } bool found_unknown_shapes = false; - auto costs = - PredictOpCountBasedCost(CountConv2DBackpropFilterOperations( - op_info, nullptr, &found_unknown_shapes), - op_info); - costs.inaccurate = found_unknown_shapes; - costs.num_ops_with_unknown_shapes = found_unknown_shapes; - return costs; + int64 num_compute_ops = CountConv2DBackpropFilterOperations( + op_info, nullptr, &found_unknown_shapes); + return PredictDefaultNodeCosts(num_compute_ops, op_context, + &found_unknown_shapes, node_costs); } -Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation( - const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictFusedConv2DBiasActivation( + const OpContext& op_context, NodeCosts* node_costs) const { // FusedConv2DBiasActivation computes a fused kernel which implements: // 2D convolution, adds side input with separate scaling on convolution and // side inputs, then adds bias, and finally applies the ReLU activation @@ -1639,18 +1724,16 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation( std::string data_format = GetDataFormat(op_context.op_info); if (data_format != "NCHW" && data_format != "NHWC" && data_format != "NCHW_VECT_C") { - LOG(WARNING) << "unsupported data format: " << data_format; - Costs cost = Costs::ZeroCosts(); - cost.inaccurate = true; - return cost; + return errors::InvalidArgument( + "Unsupported data format (", data_format, + ") for op: ", op_context.op_info.ShortDebugString()); } std::string filter_format = GetFilterFormat(op_context.op_info); if (filter_format != "HWIO" && filter_format != "OIHW" && filter_format != "OIHW_VECT_I") { - LOG(WARNING) << "unsupported filter format: " << filter_format; - Costs cost = Costs::ZeroCosts(); - cost.inaccurate = true; - return cost; + return errors::InvalidArgument( + "Unsupported filter format (", filter_format, + ") for op: ", op_context.op_info.ShortDebugString()); } auto& conv_input = op_context.op_info.inputs(0); @@ -1695,42 +1778,48 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation( *op_context_with_output.op_info.mutable_outputs()->Add() = output; // Construct component operations and run the cost computation. - auto costs = PredictFusedOp(op_context_with_output, component_ops); - costs.inaccurate |= found_unknown_shapes; - costs.num_ops_with_unknown_shapes = costs.inaccurate; - return costs; + if (found_unknown_shapes) { + node_costs->inaccurate = true; + node_costs->num_nodes_with_unknown_shapes = 1; + } + return PredictFusedOp(op_context_with_output, component_ops, node_costs); } -Costs OpLevelCostEstimator::PredictMatMul(const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictMatMul(const OpContext& op_context, + NodeCosts* node_costs) const { const auto& op_info = op_context.op_info; bool found_unknown_shapes = false; - auto costs = PredictOpCountBasedCost( - CountMatMulOperations(op_info, &found_unknown_shapes), op_info); - costs.inaccurate = found_unknown_shapes; - costs.num_ops_with_unknown_shapes = found_unknown_shapes; - return costs; + int64 num_compute_ops = CountMatMulOperations(op_info, &found_unknown_shapes); + return PredictDefaultNodeCosts(num_compute_ops, op_context, + &found_unknown_shapes, node_costs); } -Costs OpLevelCostEstimator::PredictEinsum(const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictEinsum(const OpContext& op_context, + NodeCosts* node_costs) const { const auto& op_info = op_context.op_info; auto it = op_info.attr().find("equation"); - if (it == op_info.attr().end()) return Costs::ZeroCosts(/*inaccurate=*/true); + if (it == op_info.attr().end()) { + return errors::InvalidArgument("Einsum op doesn't have equation attr: ", + op_info.ShortDebugString()); + } + OpContext batch_matmul_op_context; bool found_unknown_shapes = false; bool success = GenerateBatchMatmulContextFromEinsum( op_context, &batch_matmul_op_context, &found_unknown_shapes); - if (!success) { - return PredictCostOfAnUnknownOp(op_context); + if (found_unknown_shapes) { + node_costs->inaccurate = true; + node_costs->num_nodes_with_unknown_shapes = 1; } - Costs costs = PredictCosts(batch_matmul_op_context); - costs.inaccurate = costs.inaccurate || found_unknown_shapes; - costs.num_ops_with_unknown_shapes = found_unknown_shapes; - return costs; + if (!success) { + return PredictCostOfAnUnknownOp(op_context, node_costs); + } + return PredictNodeCosts(batch_matmul_op_context, node_costs); } -Costs OpLevelCostEstimator::PredictSparseTensorDenseMatMul( - const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictSparseTensorDenseMatMul( + const OpContext& op_context, NodeCosts* node_costs) const { const auto& op_info = op_context.op_info; bool found_unknown_shapes = false; // input[0]: indices in sparse matrix a @@ -1758,93 +1847,113 @@ Costs OpLevelCostEstimator::PredictSparseTensorDenseMatMul( CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes); int64 b_input_size = num_elems_in_a * n_dim * DataTypeSize(BaseType(b_matrix.dtype())); - double input_size = a_indices_input_size + a_values_input_size + - a_shape_input_size + b_input_size; + int64 output_size = CalculateOutputSize(op_info, &found_unknown_shapes); - double output_size = CalculateOutputSize(op_info, &found_unknown_shapes); - - auto costs = - PredictOpCountBasedCost(op_count, input_size, output_size, op_info); - costs.inaccurate = found_unknown_shapes; - costs.num_ops_with_unknown_shapes = found_unknown_shapes; - costs.max_memory = output_size; - - return costs; + node_costs->num_compute_ops = op_count; + node_costs->num_input_bytes_accessed = {a_indices_input_size, + a_values_input_size, + a_shape_input_size, b_input_size}; + node_costs->num_output_bytes_accessed = {output_size}; + if (found_unknown_shapes) { + node_costs->inaccurate = true; + node_costs->num_nodes_with_unknown_shapes = 1; + } + return Status::OK(); } -Costs OpLevelCostEstimator::PredictNoOp(const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictNoOp(const OpContext& op_context, + NodeCosts* node_costs) const { const auto& op_info = op_context.op_info; VLOG(1) << "Op:" << op_info.op() << " Execution Time 0 (ns)"; - return Costs::ZeroCosts(); + // By default, NodeCosts is initialized to zero ops and bytes. + return Status::OK(); } -Costs OpLevelCostEstimator::PredictPureMemoryOp( - const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictPureMemoryOp(const OpContext& op_context, + NodeCosts* node_costs) const { // Each output element is a copy of some element from input, with no required // computation, so just compute memory costs. - return PredictOpCountBasedCost(0, op_context.op_info); + bool found_unknown_shapes = false; + node_costs->num_nodes_with_pure_memory_op = 1; + return PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes, + node_costs); } -Costs OpLevelCostEstimator::PredictIdentity(const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictIdentity(const OpContext& op_context, + NodeCosts* node_costs) const { const auto& op_info = op_context.op_info; - VLOG(1) << "Op:" << op_info.op() << " Execution Time 0 (ns)"; - Costs result = Costs::ZeroCosts(); - result.max_memory = CalculateOutputSize(op_info, &result.inaccurate); - result.num_ops_with_unknown_shapes = result.inaccurate; - // Assign the minimum amount of time we can represent to the identity op since - // it tends to be really cheap. - result.compute_time = kMinComputeTime; - result.execution_time = result.compute_time; - return result; + VLOG(1) << "Op:" << op_info.op() << " Minimum cost for Identity"; + node_costs->minimum_cost_op = true; + node_costs->num_compute_ops = kMinComputeOp; + // Identity op internally pass input tensor buffer's pointer to the output + // tensor buffer; no actual memory operation. + node_costs->num_input_bytes_accessed = {0}; + node_costs->num_output_bytes_accessed = {0}; + bool inaccurate = false; + node_costs->max_memory = CalculateOutputSize(op_info, &inaccurate); + if (inaccurate) { + node_costs->inaccurate = true; + node_costs->num_nodes_with_unknown_shapes = 1; + } + return Status::OK(); } -Costs OpLevelCostEstimator::PredictVariable(const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictVariable(const OpContext& op_context, + NodeCosts* node_costs) const { const auto& op_info = op_context.op_info; - VLOG(1) << "Op:" << op_info.op() << " Execution Time 0 (ns)"; - Costs result = Costs::ZeroCosts(); - result.persistent_memory = CalculateOutputSize(op_info, &result.inaccurate); - result.num_ops_with_unknown_shapes = result.inaccurate; - - result.compute_time = kMinComputeTime; - result.execution_time = result.compute_time; - return result; + VLOG(1) << "Op:" << op_info.op() << " Minimum cost for Variable"; + node_costs->minimum_cost_op = true; + node_costs->num_compute_ops = kMinComputeOp; + // Variables are persistent ops; initialized before step; hence, no memory + // cost. + node_costs->num_input_bytes_accessed = {0}; + node_costs->num_output_bytes_accessed = {0}; + bool inaccurate = false; + node_costs->persistent_memory = CalculateOutputSize(op_info, &inaccurate); + if (inaccurate) { + node_costs->inaccurate = true; + node_costs->num_nodes_with_unknown_shapes = 1; + } + return Status::OK(); } -Costs OpLevelCostEstimator::PredictBatchMatMul( - const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictBatchMatMul(const OpContext& op_context, + NodeCosts* node_costs) const { const auto& op_info = op_context.op_info; bool found_unknown_shapes = false; - Costs costs = PredictOpCountBasedCost( - CountBatchMatMulOperations(op_info, &found_unknown_shapes), op_info); - costs.inaccurate = found_unknown_shapes; - costs.num_ops_with_unknown_shapes = found_unknown_shapes; - return costs; + int64 num_compute_ops = + CountBatchMatMulOperations(op_info, &found_unknown_shapes); + return PredictDefaultNodeCosts(num_compute_ops, op_context, + &found_unknown_shapes, node_costs); } -Costs OpLevelCostEstimator::PredictMetadata(const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictMetadata(const OpContext& op_context, + NodeCosts* node_costs) const { const auto& op_info = op_context.op_info; - Costs costs = Costs::ZeroCosts(); - costs.max_memory = CalculateOutputSize(op_info, &costs.inaccurate); - costs.num_ops_with_unknown_shapes = costs.inaccurate; - // Metadata operations are so cheap we assume they take the minimum amount of - // time we can represent (1 ns). - costs.compute_time = kMinComputeTime; - costs.execution_time = costs.compute_time; - - return costs; + node_costs->minimum_cost_op = true; + node_costs->num_compute_ops = kMinComputeOp; + node_costs->num_input_bytes_accessed = {0}; + node_costs->num_output_bytes_accessed = {0}; + bool inaccurate = false; + node_costs->max_memory = CalculateOutputSize(op_info, &inaccurate); + if (inaccurate) { + node_costs->inaccurate = true; + node_costs->num_nodes_with_unknown_shapes = 1; + } + return Status::OK(); } -Costs OpLevelCostEstimator::PredictGatherOrSlice( - const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictGatherOrSlice(const OpContext& op_context, + NodeCosts* node_costs) const { // Gather & Slice ops can have a very large input, but only access a small // part of it. For these op the size of the output determines the memory cost. const auto& op_info = op_context.op_info; const int inputs_needed = op_info.op() == "Slice" ? 3 : 2; if (op_info.outputs_size() == 0 || op_info.inputs_size() < inputs_needed) { - Costs costs = Costs::ZeroCosts(); - costs.inaccurate = true; - return costs; + return errors::InvalidArgument( + op_info.op(), + " Op doesn't have valid input / output: ", op_info.ShortDebugString()); } bool unknown_shapes = false; @@ -1853,10 +1962,19 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice( // For roofline estimate we assume each copy has a unit cost. const int64 op_count = CalculateTensorElementCount(op_info.outputs(0), &unknown_shapes); + node_costs->num_compute_ops = op_count; - const double output_size = CalculateOutputSize(op_info, &unknown_shapes); - double input_size = output_size; - int begin_input_index = 1, end_input_index; + const int64 output_size = CalculateOutputSize(op_info, &unknown_shapes); + node_costs->num_output_bytes_accessed = {output_size}; + + node_costs->num_input_bytes_accessed.reserve(op_info.inputs().size()); + int64 input_size = output_size; + // Note that input(0) byte accessed is not equal to input(0) tensor size. + // It's equal to the output size; though, input access is indexed gather or + // slice (ignore duplicate indices). + node_costs->num_input_bytes_accessed.push_back(input_size); + int begin_input_index = 1; + int end_input_index; if (op_info.op() == "Slice") { // Slice: 'input' (omitted), 'begin', 'size' end_input_index = 3; @@ -1868,20 +1986,18 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice( end_input_index = 2; } for (int i = begin_input_index; i < end_input_index; ++i) { - input_size += - CalculateTensorElementCount(op_info.inputs(i), &unknown_shapes); + node_costs->num_input_bytes_accessed.push_back( + CalculateTensorElementCount(op_info.inputs(i), &unknown_shapes)); } - - Costs costs = - PredictOpCountBasedCost(op_count, input_size, output_size, op_info); - costs.inaccurate = unknown_shapes; - costs.num_ops_with_unknown_shapes = unknown_shapes; - costs.max_memory = output_size; - - return costs; + if (unknown_shapes) { + node_costs->inaccurate = true; + node_costs->num_nodes_with_unknown_shapes = 1; + } + return Status::OK(); } -Costs OpLevelCostEstimator::PredictScatter(const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictScatter(const OpContext& op_context, + NodeCosts* node_costs) const { // Scatter ops sparsely access a reference input and output tensor. const auto& op_info = op_context.op_info; bool found_unknown_shapes = false; @@ -1904,6 +2020,7 @@ Costs OpLevelCostEstimator::PredictScatter(const OpContext& op_context) const { num_elems_in_ref_per_index *= ref_tensor_shape.dim(i).size(); } const int64 op_count = num_indices * num_elems_in_ref_per_index; + node_costs->num_compute_ops = op_count; // Sparsely access ref so input size depends on the number of operations int64 ref_input_size = @@ -1912,44 +2029,50 @@ Costs OpLevelCostEstimator::PredictScatter(const OpContext& op_context) const { CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes); int64 updates_input_size = CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes); - - double total_input_size = - ref_input_size + indices_input_size + updates_input_size; + node_costs->num_input_bytes_accessed = {ref_input_size, indices_input_size, + updates_input_size}; // Sparsely access ref so output size depends on the number of operations - double total_output_size = + int64 output_size = op_count * DataTypeSize(BaseType(op_info.outputs(0).dtype())); + node_costs->num_output_bytes_accessed = {output_size}; - auto costs = PredictOpCountBasedCost(op_count, total_input_size, - total_output_size, op_info); - costs.inaccurate = found_unknown_shapes; - costs.num_ops_with_unknown_shapes = found_unknown_shapes; - - return costs; + if (found_unknown_shapes) { + node_costs->inaccurate = true; + node_costs->num_nodes_with_unknown_shapes = 1; + } + return Status::OK(); } -Costs OpLevelCostEstimator::PredictFusedOp( +Status OpLevelCostEstimator::PredictFusedOp( const OpContext& op_context, - const std::vector& fused_op_contexts) const { - // Note that PredictOpCountBasedCost will get the correct memory_time from + const std::vector& fused_op_contexts, + NodeCosts* node_costs) const { + // Note that PredictDefaultNodeCosts will get the correct memory costs from // the node's inputs and outputs; but we don't want to have to re-implement // the logic for computing the operation count of each of our component // operations here; so we simply add the compute times of each component - // operation, then update the execution time. - Costs fused_cost = PredictOpCountBasedCost(0, op_context.op_info); + // operation, then update the cost. + bool found_unknown_shapes = false; + Status s = + PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes, node_costs); - fused_cost.compute_time = 0; - fused_cost.inaccurate = false; for (auto& fused_op : fused_op_contexts) { - auto op_cost = PredictCosts(fused_op); - - fused_cost.compute_time += op_cost.compute_time; - fused_cost.inaccurate |= op_cost.inaccurate; - fused_cost.intermediate_memory_time += op_cost.intermediate_memory_time; + NodeCosts fused_node_costs; + s.Update(PredictNodeCosts(fused_op, &fused_node_costs)); + node_costs->num_compute_ops += fused_node_costs.num_compute_ops; + node_costs->inaccurate |= fused_node_costs.inaccurate; + // Set, not increment. Note that we are predicting the cost of one fused + // node, not a function node composed of many nodes. + node_costs->num_nodes_with_unknown_shapes |= + fused_node_costs.num_nodes_with_unknown_shapes; + node_costs->num_nodes_with_unknown_op_type |= + fused_node_costs.num_nodes_with_unknown_op_type; + node_costs->num_nodes_with_pure_memory_op |= + fused_node_costs.num_nodes_with_pure_memory_op; } - CombineCostsAndUpdateExecutionTime(compute_memory_overlap_, &fused_cost); - return fused_cost; + return Status::OK(); } /* static */ @@ -2040,7 +2163,8 @@ OpLevelCostEstimator::OpDimensionsFromInputs( return conv_dims; } -Costs OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context, + NodeCosts* node_costs) const { bool found_unknown_shapes = false; const auto& op_info = op_context.op_info; // x: op_info.inputs(0) @@ -2050,38 +2174,41 @@ Costs OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context) const { // or 1 copy per output (kx * k1 = 1). int per_output_ops = dims.kx * dims.ky == 1 ? 1 : dims.kx * dims.ky - 1; int64 ops = dims.batch * dims.ox * dims.oy * dims.oz * per_output_ops; + node_costs->num_compute_ops = ops; - double total_input_size = 0; + int64 input_size = 0; if (dims.ky >= dims.sy) { - total_input_size = - CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes); + input_size = CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes); } else { // dims.ky < dims.sy // Vertical stride is larger than vertical kernel; assuming row-major // format, skip unnecessary rows (or read every kx rows per sy rows, as the // others are not used for output). const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype())); - total_input_size = - data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz; + input_size = data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz; } - const double total_output_size = - CalculateOutputSize(op_info, &found_unknown_shapes); - - Costs costs = PredictOpCountBasedCost(ops, total_input_size, - total_output_size, op_info); - costs.inaccurate = found_unknown_shapes; - costs.num_ops_with_unknown_shapes = found_unknown_shapes; - costs.max_memory = total_output_size; - return costs; + node_costs->num_input_bytes_accessed = {input_size}; + const int64 output_size = CalculateOutputSize(op_info, &found_unknown_shapes); + node_costs->num_output_bytes_accessed = {output_size}; + node_costs->max_memory = output_size; + if (found_unknown_shapes) { + node_costs->inaccurate = true; + node_costs->num_nodes_with_unknown_shapes = 1; + } + return Status::OK(); } -Costs OpLevelCostEstimator::PredictMaxPoolGrad( - const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictMaxPoolGrad(const OpContext& op_context, + NodeCosts* node_costs) const { bool found_unknown_shapes = false; const auto& op_info = op_context.op_info; // x: op_info.inputs(0) // y: op_info.inputs(1) // y_grad: op_info.inputs(2) - if (op_info.inputs_size() < 3) return Costs::ZeroCosts(/*inaccurate=*/true); + if (op_info.inputs_size() < 3) { + return errors::InvalidArgument("MaxPoolGrad op has invalid inputs: ", + op_info.ShortDebugString()); + } + ConvolutionDimensions dims = OpDimensionsFromInputs( op_info.inputs(0).shape(), op_info, &found_unknown_shapes); @@ -2099,48 +2226,62 @@ Costs OpLevelCostEstimator::PredictMaxPoolGrad( ops = dims.batch * dims.iz * (dims.ox * dims.oy * (dims.kx * dims.ky - 1) + dims.ix * dims.iy * 2); } + node_costs->num_compute_ops = ops; // Just read x and y_grad; no need to read y as we assume MaxPoolGrad re-run // MaxPool internally. - double total_input_size = + const int64 input0_size = CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes); - total_input_size += + const int64 input2_size = CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes); + node_costs->num_input_bytes_accessed = {input0_size, 0, input2_size}; // Write x_grad; size equal to x. - const double total_output_size = + const int64 output_size = CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes); + node_costs->num_output_bytes_accessed = {output_size}; + node_costs->max_memory = output_size; - Costs costs = PredictOpCountBasedCost(ops, total_input_size, - total_output_size, op_info); - costs.inaccurate = found_unknown_shapes; - costs.num_ops_with_unknown_shapes = found_unknown_shapes; - costs.max_memory = total_output_size; - return costs; + if (found_unknown_shapes) { + node_costs->inaccurate = true; + node_costs->num_nodes_with_unknown_shapes = 1; + } + return Status::OK(); } /* This predict function handles three types of tensorflow ops * AssignVariableOp/AssignAddVariableOp/AssignSubVariableOp, broadcasting * was not possible for these ops, therefore the input tensor's shapes is * enough to compute the cost */ -Costs OpLevelCostEstimator::PredictAssignVariableOps( - const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictAssignVariableOps( + const OpContext& op_context, NodeCosts* node_costs) const { bool found_unknown_shapes = false; const auto& op_info = op_context.op_info; /* First input of these ops are reference to the assignee. */ - if (op_info.inputs_size() != 2) return Costs::ZeroCosts(true); - const double total_input_size = - CalculateInputSize(op_info, &found_unknown_shapes); - const double flops = op_info.op() == kAssignVariableOp - ? 0.0 - : CalculateTensorElementCount(op_info.inputs(1), - &found_unknown_shapes); - Costs costs = PredictOpCountBasedCost(flops, total_input_size, 0, op_info); - costs.inaccurate = found_unknown_shapes; - costs.num_ops_with_unknown_shapes = found_unknown_shapes; - return costs; + if (op_info.inputs_size() != 2) { + return errors::InvalidArgument("AssignVariable op has invalid input: ", + op_info.ShortDebugString()); + } + + const int64 ops = op_info.op() == kAssignVariableOp + ? 0 + : CalculateTensorElementCount(op_info.inputs(1), + &found_unknown_shapes); + node_costs->num_compute_ops = ops; + const int64 input_size = CalculateInputSize(op_info, &found_unknown_shapes); + node_costs->num_input_bytes_accessed = {input_size}; + // TODO(dyoon): check these ops' behavior whether it writes data; + // Op itself doesn't have output tensor, but it may modify the input (ref or + // resource). Maybe use node_costs->internal_write_bytes. + node_costs->num_output_bytes_accessed = {0}; + if (found_unknown_shapes) { + node_costs->inaccurate = true; + node_costs->num_nodes_with_unknown_shapes = 1; + } + return Status::OK(); } -Costs OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context, + NodeCosts* node_costs) const { bool found_unknown_shapes = false; const auto& op_info = op_context.op_info; // x: op_info.inputs(0) @@ -2149,32 +2290,33 @@ Costs OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context) const { // kx * ky - 1 additions and 1 multiplication per output. int64 ops = dims.batch * dims.ox * dims.oy * dims.oz * dims.kx * dims.ky; + node_costs->num_compute_ops = ops; - double total_input_size = 0; + int64 input_size; if (dims.ky >= dims.sy) { - total_input_size = - CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes); + input_size = CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes); } else { // dims.ky < dims.sy // vertical stride is larger than vertical kernel; assuming row-major // format, skip unnecessary rows (or read every kx rows per sy rows, as the // others are not used for output). const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype())); - total_input_size = - data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz; + input_size = data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz; } - const double total_output_size = - CalculateOutputSize(op_info, &found_unknown_shapes); + node_costs->num_input_bytes_accessed = {input_size}; - Costs costs = PredictOpCountBasedCost(ops, total_input_size, - total_output_size, op_info); - costs.inaccurate = found_unknown_shapes; - costs.num_ops_with_unknown_shapes = found_unknown_shapes; - costs.max_memory = total_output_size; - return costs; + const int64 output_size = CalculateOutputSize(op_info, &found_unknown_shapes); + node_costs->num_output_bytes_accessed = {output_size}; + node_costs->max_memory = output_size; + + if (found_unknown_shapes) { + node_costs->inaccurate = true; + node_costs->num_nodes_with_unknown_shapes = 1; + } + return Status::OK(); } -Costs OpLevelCostEstimator::PredictAvgPoolGrad( - const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictAvgPoolGrad(const OpContext& op_context, + NodeCosts* node_costs) const { bool found_unknown_shapes = false; const auto& op_info = op_context.op_info; // x's shape: op_info.inputs(0) @@ -2212,22 +2354,14 @@ Costs OpLevelCostEstimator::PredictAvgPoolGrad( ops = dims.batch * dims.iz * (dims.ix * dims.iy + dims.ox * dims.oy * (dims.kx * dims.ky + 1)); } - - const double total_input_size = - CalculateInputSize(op_info, &found_unknown_shapes); - const double total_output_size = - CalculateOutputSize(op_info, &found_unknown_shapes); - - Costs costs = PredictOpCountBasedCost(ops, total_input_size, - total_output_size, op_info); - costs.inaccurate = found_unknown_shapes; - costs.num_ops_with_unknown_shapes = found_unknown_shapes; - costs.max_memory = total_output_size; - return costs; + auto s = PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes, + node_costs); + node_costs->max_memory = node_costs->num_total_output_bytes(); + return s; } -Costs OpLevelCostEstimator::PredictFusedBatchNorm( - const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictFusedBatchNorm( + const OpContext& op_context, NodeCosts* node_costs) const { bool found_unknown_shapes = false; const auto& op_info = op_context.op_info; // x: op_info.inputs(0) @@ -2247,34 +2381,37 @@ Costs OpLevelCostEstimator::PredictFusedBatchNorm( } else { ops = dims.batch * dims.ix * dims.iy * dims.iz * 2; } + node_costs->num_compute_ops = ops; - const double size_nhwc = + const int64 size_nhwc = CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes); - const double size_c = + const int64 size_c = CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes); - double total_input_size = 0.0; - double total_internal_read_size = 0.0; - double total_output_size = 0.0; if (is_training) { - total_input_size = size_nhwc + size_c * 2; - total_output_size = size_nhwc + size_c * 4; - total_internal_read_size = size_nhwc; + node_costs->num_input_bytes_accessed = {size_nhwc, size_c, size_c}; + node_costs->num_output_bytes_accessed = {size_nhwc, size_c, size_c, size_c, + size_c}; + // FusedBatchNorm in training mode internally re-reads the input tensor: + // one for mean/variance, and the 2nd internal read forthe actual scaling. + // Assume small intermediate data such as mean / variance (size_c) can be + // cached on-chip. + node_costs->internal_read_bytes = size_nhwc; } else { - total_input_size = size_nhwc + size_c * 4; - total_output_size = size_nhwc; + node_costs->num_input_bytes_accessed = {size_nhwc, size_c, size_c, size_c, + size_c}; + node_costs->num_output_bytes_accessed = {size_nhwc}; } + node_costs->max_memory = node_costs->num_total_output_bytes(); - Costs costs = - PredictOpCountBasedCost(ops, total_input_size + total_internal_read_size, - total_output_size, op_info); - costs.inaccurate = found_unknown_shapes; - costs.num_ops_with_unknown_shapes = found_unknown_shapes; - costs.max_memory = total_output_size; - return costs; + if (found_unknown_shapes) { + node_costs->inaccurate = true; + node_costs->num_nodes_with_unknown_shapes = 1; + } + return Status::OK(); } -Costs OpLevelCostEstimator::PredictFusedBatchNormGrad( - const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictFusedBatchNormGrad( + const OpContext& op_context, NodeCosts* node_costs) const { bool found_unknown_shapes = false; const auto& op_info = op_context.op_info; // y_backprop: op_info.inputs(0) @@ -2289,25 +2426,29 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad( const auto rsqrt_cost = Eigen::internal::functor_traits< Eigen::internal::scalar_rsqrt_op>::Cost; ops = dims.iz * (dims.batch * dims.ix * dims.iy * 11 + 5 + rsqrt_cost); + node_costs->num_compute_ops = ops; - const double size_nhwc = + const int64 size_nhwc = CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes); - const double size_c = + const int64 size_c = CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes); - double total_input_size = size_nhwc * 2 + size_c * 2; - double total_internal_read_size = size_nhwc; - double total_output_size = size_nhwc * 1 + size_c * 2; + // TODO(dyoon): fix missing memory cost for variance input (size_c) and + // yet another read of y_backprop (size_nhwc) internally. + node_costs->num_input_bytes_accessed = {size_nhwc, size_nhwc, size_c, size_c}; + node_costs->num_output_bytes_accessed = {size_nhwc, size_c, size_c}; + // FusedBatchNormGrad has to read y_backprop internally. + node_costs->internal_read_bytes = size_nhwc; + node_costs->max_memory = node_costs->num_total_output_bytes(); - Costs costs = - PredictOpCountBasedCost(ops, total_input_size + total_internal_read_size, - total_output_size, op_info); - costs.inaccurate = found_unknown_shapes; - costs.num_ops_with_unknown_shapes = found_unknown_shapes; - costs.max_memory = total_output_size; - return costs; + if (found_unknown_shapes) { + node_costs->inaccurate = true; + node_costs->num_nodes_with_unknown_shapes = 1; + } + return Status::OK(); } -Costs OpLevelCostEstimator::PredictNaryOp(const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictNaryOp(const OpContext& op_context, + NodeCosts* node_costs) const { const auto& op_info = op_context.op_info; bool found_unknown_shapes = false; // Calculate the largest known tensor size across all inputs and output. @@ -2331,21 +2472,22 @@ Costs OpLevelCostEstimator::PredictNaryOp(const OpContext& op_context) const { const auto sum_cost = Eigen::internal::functor_traits< Eigen::internal::scalar_sum_op>::Cost; - Costs costs = PredictOpCountBasedCost(op_count * sum_cost, op_info); - if (found_unknown_shapes) { - costs.inaccurate = true; - } - costs.num_ops_with_unknown_shapes = found_unknown_shapes; - return costs; + return PredictDefaultNodeCosts(op_count * sum_cost, op_context, + &found_unknown_shapes, node_costs); } // softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j])) -Costs OpLevelCostEstimator::PredictSoftmax(const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictSoftmax(const OpContext& op_context, + NodeCosts* node_costs) const { bool found_unknown_shapes = false; const int64 logits_size = CalculateTensorElementCount( op_context.op_info.inputs(0), &found_unknown_shapes); - TensorShapeProto logits_shape = MaybeGetMinimumShape( - op_context.op_info.inputs(0).shape(), 2, &found_unknown_shapes); + // Softmax input rank should be >=1. + TensorShapeProto logits_shape = op_context.op_info.inputs(0).shape(); + if (logits_shape.unknown_rank() || logits_shape.dim_size() == 0) { + return errors::InvalidArgument("Softmax op has invalid input: ", + op_context.op_info.ShortDebugString()); + } #define EIGEN_COST(X) Eigen::internal::functor_traits::Cost @@ -2359,23 +2501,21 @@ Costs OpLevelCostEstimator::PredictSoftmax(const OpContext& op_context) const { EIGEN_COST(scalar_inverse_op) * logits_shape.dim(0).size(); #undef EIGEN_COST - - return PredictOpCountBasedCost(ops, op_context.op_info); + return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes, + node_costs); } -Costs OpLevelCostEstimator::PredictResizeBilinear( - const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictResizeBilinear( + const OpContext& op_context, NodeCosts* node_costs) const { bool found_unknown_shapes = false; if (op_context.op_info.outputs().empty() || op_context.op_info.inputs().empty()) { - return Costs::ZeroCosts(/*inaccurate=*/true); + return errors::InvalidArgument( + "ResizeBilinear op has invalid input / output ", + op_context.op_info.ShortDebugString()); } - const int64 input_size = - CalculateTensorSize(op_context.op_info.inputs(0), &found_unknown_shapes); - const int64 output_size = - CalculateTensorSize(op_context.op_info.outputs(0), &found_unknown_shapes); const int64 output_elements = CalculateTensorElementCount( op_context.op_info.outputs(0), &found_unknown_shapes); @@ -2384,7 +2524,7 @@ Costs OpLevelCostEstimator::PredictResizeBilinear( bool use_half_pixel_centers = false; if (half_pixel_centers == op_context.op_info.attr().end()) { LOG(WARNING) << "half_pixel_centers attr not set for ResizeBilinear."; - return PredictCostOfAnUnknownOp(op_context); + return PredictCostOfAnUnknownOp(op_context, node_costs); } else { use_half_pixel_centers = half_pixel_centers->second.b(); } @@ -2454,12 +2594,12 @@ Costs OpLevelCostEstimator::PredictResizeBilinear( // return top + (bottom - top) * y_lerp; ops += (add_cost * 3 + sub_cost_float * 3 + mul_cost * 3) * output_elements; - return PredictOpCountBasedCost(ops, input_size, output_size, - op_context.op_info); + return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes, + node_costs); } -Costs OpLevelCostEstimator::PredictCropAndResize( - const OpContext& op_context) const { +Status OpLevelCostEstimator::PredictCropAndResize(const OpContext& op_context, + NodeCosts* node_costs) const { bool found_unknown_shapes = false; const auto method = op_context.op_info.attr().find("method"); @@ -2472,14 +2612,9 @@ Costs OpLevelCostEstimator::PredictCropAndResize( } else { LOG(WARNING) << "method attr in CropAndResize invalid; expected bilinear " "or nearest."; - return PredictCostOfAnUnknownOp(op_context); + return PredictCostOfAnUnknownOp(op_context, node_costs); } - const int input_size = - CalculateTensorSize(op_context.op_info.inputs(0), &found_unknown_shapes); - const int output_size = - CalculateOutputSize(op_context.op_info, &found_unknown_shapes); - const int64 num_boxes = op_context.op_info.inputs(1).shape().dim(0).size(); const auto crop_shape = MaybeGetMinimumShape( op_context.op_info.outputs(0).shape(), 4, &found_unknown_shapes); @@ -2529,9 +2664,8 @@ Costs OpLevelCostEstimator::PredictCropAndResize( // Ops for innermost loop across depth. ops += cast_to_float_cost * output_elements; } - - return PredictOpCountBasedCost(ops, input_size, output_size, - op_context.op_info); + return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes, + node_costs); } } // end namespace grappler diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index 238e1595ec4..3d58d13f423 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -16,9 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_ #define TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_ +#include + #include "tensorflow/core/grappler/costs/cost_estimator.h" #include "tensorflow/core/grappler/costs/op_context.h" #include "tensorflow/core/grappler/costs/op_performance_data.pb.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/util/padding.h" namespace tensorflow { @@ -29,6 +32,62 @@ bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto, TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape, int rank, bool* found_unknown_shapes); +// Node costs; an intermediate structure used within op level cost estimator. +struct NodeCosts { + // If this FLAG is true, override calculated compute time with a minimum + // value, instead of calculating it from num_compute_ops and compute ops/sec. + // For example, PredictIdentity, PredictVariable, PredictMetadata set this + // FLAG. + bool minimum_cost_op = false; + + // Compute ops. + int64 num_compute_ops = 0; + + // Memory bytes accessed; note that these may be different to the size of + // tensors. + std::vector num_input_bytes_accessed; // ordered by input tensors. + std::vector num_output_bytes_accessed; // ordered by output ports. + int64 internal_read_bytes = 0; + int64 internal_write_bytes = 0; + + // Convenience functions. + int64 num_total_input_bytes() { + return std::accumulate(num_input_bytes_accessed.begin(), + num_input_bytes_accessed.end(), 0LL); + } + int64 num_total_read_bytes() { + return num_total_input_bytes() + internal_read_bytes; + } + int64 num_total_output_bytes() { + return std::accumulate(num_output_bytes_accessed.begin(), + num_output_bytes_accessed.end(), 0LL); + } + int64 num_total_write_bytes() { + return num_total_output_bytes() + internal_write_bytes; + } + int64 num_bytes_accessed() { + return num_total_read_bytes() + num_total_write_bytes(); + } + + // Memory usage. + int64 max_memory = 0; + int64 persistent_memory = 0; + int64 temporary_memory = 0; + + // Stats. + int64 num_nodes = 1; + int64 num_nodes_with_unknown_shapes = 0; + int64 num_nodes_with_unknown_op_type = 0; + int64 num_nodes_with_pure_memory_op = 0; + bool inaccurate = false; + + // TODO(dyoon): this is added for compatibility; some old code is hard to + // migrate; hence, using these as a backup. Once we clean up, we'll delete + // these fields. New code should not use these. + bool has_costs = false; + Costs costs; +}; + class OpLevelCostEstimator { public: OpLevelCostEstimator(); @@ -40,9 +99,7 @@ class OpLevelCostEstimator { virtual DeviceInfo GetDeviceInfo(const DeviceProperties& device) const; protected: - // Predict cost of an op for which no accurate estimator is defined. - Costs PredictCostOfAnUnknownOp(const OpContext& op_context) const; - + // TODO(dyoon): Consider to remove PredictOpCountBasedCosts() with OpInfo. // Naive cost estimate based on the given operations count and total // input/output tensor sizes of the given op_info combined. Costs PredictOpCountBasedCost(double operations, const OpInfo& op_info) const; @@ -54,6 +111,16 @@ class OpLevelCostEstimator { double output_io_bytes, const OpInfo& op_info) const; + // Top-level method cost function (PredictCosts calls this method to get + // NodeCosts, and then converts it to Costs). PredictNodeCosts() calls other + // Predict methods depending on op types. + Status PredictNodeCosts(const OpContext& op_context, + NodeCosts* node_costs) const; + + // Predict cost of an op for which no accurate estimator is defined. + Status PredictCostOfAnUnknownOp(const OpContext& op_context, + NodeCosts* node_costs) const; + // This family of routines predicts the costs to // perform the specified TensorFlow Op on the // device represented by a subclass. The default @@ -64,37 +131,64 @@ class OpLevelCostEstimator { // Implementation of costs other than // execution_time is optional, depending on the // device. - Costs PredictNaryOp(const OpContext& op_context) const; - Costs PredictConv2D(const OpContext& op_context) const; - Costs PredictCwiseOp(const OpContext& op_context) const; - Costs PredictConv2DBackpropInput(const OpContext& op_context) const; - Costs PredictConv2DBackpropFilter(const OpContext& op_context) const; - Costs PredictFusedConv2DBiasActivation(const OpContext& op_context) const; - Costs PredictMatMul(const OpContext& op_context) const; - Costs PredictSparseTensorDenseMatMul(const OpContext& op_context) const; - Costs PredictNoOp(const OpContext& op_context) const; - Costs PredictIdentity(const OpContext& op_context) const; - Costs PredictVariable(const OpContext& op_context) const; - Costs PredictBatchMatMul(const OpContext& op_context) const; - Costs PredictMetadata(const OpContext& op_context) const; - Costs PredictGatherOrSlice(const OpContext& op_context) const; - Costs PredictScatter(const OpContext& op_context) const; - Costs PredictMaxPool(const OpContext& op_context) const; - Costs PredictMaxPoolGrad(const OpContext& op_context) const; - Costs PredictAvgPool(const OpContext& op_context) const; - Costs PredictAvgPoolGrad(const OpContext& op_context) const; - Costs PredictFusedBatchNorm(const OpContext& op_context) const; - Costs PredictFusedBatchNormGrad(const OpContext& op_context) const; - Costs PredictEinsum(const OpContext& op_context) const; - Costs PredictAssignVariableOps(const OpContext& op_context) const; - Costs PredictPureMemoryOp(const OpContext& op_context) const; - Costs PredictSoftmax(const OpContext& op_context) const; - Costs PredictResizeBilinear(const OpContext& op_context) const; - Costs PredictCropAndResize(const OpContext& op_context) const; + Status PredictNaryOp(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictConv2D(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictCwiseOp(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictConv2DBackpropInput(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictConv2DBackpropFilter(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictFusedConv2DBiasActivation(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictMatMul(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictSparseTensorDenseMatMul(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictNoOp(const OpContext& op_context, NodeCosts* node_costs) const; + Status PredictIdentity(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictVariable(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictBatchMatMul(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictMetadata(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictGatherOrSlice(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictScatter(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictMaxPool(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictMaxPoolGrad(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictAvgPool(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictAvgPoolGrad(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictFusedBatchNorm(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictFusedBatchNormGrad(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictEinsum(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictAssignVariableOps(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictPureMemoryOp(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictSoftmax(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictResizeBilinear(const OpContext& op_context, + NodeCosts* node_costs) const; + Status PredictCropAndResize(const OpContext& op_context, + NodeCosts* node_costs) const; // Generic cost prediction method for fused operations. - Costs PredictFusedOp(const OpContext& op_context, - const std::vector& fused_op_contexts) const; + Status PredictFusedOp(const OpContext& op_context, + const std::vector& fused_op_contexts, + NodeCosts* node_costs) const; // Utility function for safe division. Returns 0 // if rhs is 0 or negative. @@ -176,11 +270,19 @@ class OpLevelCostEstimator { static int64 CalculateInputSize(const OpInfo& op_info, bool* found_unknown_shapes); + // Same, but a vector format: one for each input. + static std::vector CalculateInputTensorSize( + const OpInfo& op_info, bool* found_unknown_shapes); + // Calculate the total size in bytes of the all // the outputs of specified TensorFlow op. static int64 CalculateOutputSize(const OpInfo& op_info, bool* found_unknown_shapes); + // Same, but a vector format: one for each output. + static std::vector CalculateOutputTensorSize( + const OpInfo& op_info, bool* found_unknown_shapes); + // For convolution and its grad ops. static ConvolutionDimensions ConvolutionDimensionsFromInputs( const TensorShapeProto& original_image_shape, @@ -203,9 +305,16 @@ class OpLevelCostEstimator { static OpInfo::TensorProperties DescribeTensor( DataType type, const std::vector& dims); + // Helper method for building common case NodeCosts. + static Status PredictDefaultNodeCosts(const int64 num_compute_ops, + const OpContext& op_context, + bool* found_unknown_shapes, + NodeCosts* node_costs); + protected: std::map elementwise_ops_; - typedef std::function CostImpl; + typedef std::function + CostImpl; std::map device_cost_impl_; // If true, assume compute and memory overlap; hence, the op cost is max of // compute_time and memory_time, instead of sum of those two. diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc index fb2e4452b43..23373d3dc1b 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc @@ -894,8 +894,8 @@ TEST_F(OpLevelCostEstimatorTest, 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ false, "NCHW", "HWIO")); EXPECT_EQ(Costs::Duration(825345), cost.memory_time); - EXPECT_EQ(Costs::Duration(355321038), cost.compute_time); - EXPECT_EQ(Costs::Duration(356146383), cost.execution_time); + EXPECT_EQ(Costs::Duration(355321037), cost.compute_time); + EXPECT_EQ(Costs::Duration(356146382), cost.execution_time); EXPECT_EQ(cost.num_ops_total, 1); EXPECT_FALSE(cost.inaccurate); EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0); @@ -908,8 +908,8 @@ TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_HWIO) { 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true, "NCHW", "HWIO")); EXPECT_EQ(Costs::Duration(1416808), cost.memory_time); - EXPECT_EQ(Costs::Duration(355616770), cost.compute_time); - EXPECT_EQ(Costs::Duration(357033578), cost.execution_time); + EXPECT_EQ(Costs::Duration(355616768), cost.compute_time); + EXPECT_EQ(Costs::Duration(357033576), cost.execution_time); EXPECT_EQ(cost.num_ops_total, 1); EXPECT_FALSE(cost.inaccurate); EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0); @@ -922,8 +922,8 @@ TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW) { 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true, "NCHW", "OIHW")); EXPECT_EQ(Costs::Duration(1416808), cost.memory_time); - EXPECT_EQ(Costs::Duration(355616770), cost.compute_time); - EXPECT_EQ(Costs::Duration(357033578), cost.execution_time); + EXPECT_EQ(Costs::Duration(355616768), cost.compute_time); + EXPECT_EQ(Costs::Duration(357033576), cost.execution_time); EXPECT_EQ(cost.num_ops_total, 1); EXPECT_FALSE(cost.inaccurate); EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0); @@ -936,8 +936,8 @@ TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_HWIO) { 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true, "NHWC", "HWIO")); EXPECT_EQ(Costs::Duration(1416808), cost.memory_time); - EXPECT_EQ(Costs::Duration(355616770), cost.compute_time); - EXPECT_EQ(Costs::Duration(357033578), cost.execution_time); + EXPECT_EQ(Costs::Duration(355616768), cost.compute_time); + EXPECT_EQ(Costs::Duration(357033576), cost.execution_time); EXPECT_EQ(cost.num_ops_total, 1); EXPECT_FALSE(cost.inaccurate); EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0); @@ -950,8 +950,8 @@ TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_OIHW) { 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true, "NHWC", "OIHW")); EXPECT_EQ(Costs::Duration(1416808), cost.memory_time); - EXPECT_EQ(Costs::Duration(355616770), cost.compute_time); - EXPECT_EQ(Costs::Duration(357033578), cost.execution_time); + EXPECT_EQ(Costs::Duration(355616768), cost.compute_time); + EXPECT_EQ(Costs::Duration(357033576), cost.execution_time); EXPECT_EQ(cost.num_ops_total, 1); EXPECT_FALSE(cost.inaccurate); EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0); @@ -964,8 +964,8 @@ TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_VECT_C_OIHW) { 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true, "NCHW_VECT_C", "OIHW")); EXPECT_EQ(Costs::Duration(1416808), cost.memory_time); - EXPECT_EQ(Costs::Duration(355616770), cost.compute_time); - EXPECT_EQ(Costs::Duration(357033578), cost.execution_time); + EXPECT_EQ(Costs::Duration(355616768), cost.compute_time); + EXPECT_EQ(Costs::Duration(357033576), cost.execution_time); EXPECT_EQ(cost.num_ops_total, 1); EXPECT_FALSE(cost.inaccurate); EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0); @@ -978,8 +978,8 @@ TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW_VECT_I) { 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true, "NCHW", "OIHW_VECT_I")); EXPECT_EQ(Costs::Duration(1416808), cost.memory_time); - EXPECT_EQ(Costs::Duration(355616770), cost.compute_time); - EXPECT_EQ(Costs::Duration(357033578), cost.execution_time); + EXPECT_EQ(Costs::Duration(355616768), cost.compute_time); + EXPECT_EQ(Costs::Duration(357033576), cost.execution_time); EXPECT_EQ(cost.num_ops_total, 1); EXPECT_FALSE(cost.inaccurate); EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0); @@ -993,8 +993,8 @@ TEST_F(OpLevelCostEstimatorTest, 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true, "NCHW_VECT_C", "OIHW_VECT_I")); EXPECT_EQ(Costs::Duration(1416808), cost.memory_time); - EXPECT_EQ(Costs::Duration(355616770), cost.compute_time); - EXPECT_EQ(Costs::Duration(357033578), cost.execution_time); + EXPECT_EQ(Costs::Duration(355616768), cost.compute_time); + EXPECT_EQ(Costs::Duration(357033576), cost.execution_time); EXPECT_EQ(cost.num_ops_total, 1); EXPECT_FALSE(cost.inaccurate); EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0); @@ -2255,9 +2255,14 @@ TEST_F(OpLevelCostEstimatorTest, CropAndResizeExecutionTime) { DescribeTensor4D(kNumBoxes, kOutputImageDim, kOutputImageDim, kChannelSize, op_context.op_info.add_outputs()); + // Note this is time [ns, default in Duration in Costs], not bytes; + // whereas memory bandwidth from SetCpuDevice() is 10GB/s. const int kExpectedMemoryTime = - (kImageDim * kImageDim + kNumBoxes * kOutputImageDim * kOutputImageDim) * - 4; + (kImageDim * kImageDim * 4 + // input image in float. + kNumBoxes * 4 * 8 / 10 + // boxes (kNumBoxes x 4) in int64. + kNumBoxes * kOutputImageDim * kOutputImageDim * 4); // output in float. + // Note that input image and output image has kChannelSize dim, which is 10, + // hence, no need to divide it by 10 (bandwidth). { // Cost of CropAndResize with bilinear interpolation.