Add support for AddV2, Assign[|Add|Sub]VariableOp for Op Level cost estimator.
PiperOrigin-RevId: 292558662 Change-Id: I8a41e267bddb375f5acee9e575aa551136dc18ba
This commit is contained in:
parent
7e5faef4f3
commit
00712963cc
tensorflow/core/grappler/costs
@ -74,7 +74,7 @@ struct Costs {
|
||||
inline Costs();
|
||||
|
||||
// Builds a Costs structure with all zero values, rather than unknowns.
|
||||
static inline Costs ZeroCosts();
|
||||
static inline Costs ZeroCosts(bool inaccurate = false);
|
||||
|
||||
struct MilliSeconds : std::chrono::milliseconds {
|
||||
MilliSeconds() : std::chrono::milliseconds(0) {}
|
||||
@ -190,7 +190,7 @@ Costs::Costs() {
|
||||
max_per_op_streaming = kMemoryUnknown;
|
||||
}
|
||||
|
||||
Costs Costs::ZeroCosts() {
|
||||
Costs Costs::ZeroCosts(bool inaccurate) {
|
||||
Costs costs;
|
||||
costs.execution_time = Duration::zero();
|
||||
costs.compute_time = Duration::zero();
|
||||
@ -201,6 +201,7 @@ Costs Costs::ZeroCosts() {
|
||||
costs.temporary_memory = kZeroMemory;
|
||||
costs.max_per_op_buffers = kZeroMemory;
|
||||
costs.max_per_op_streaming = kZeroMemory;
|
||||
costs.inaccurate = inaccurate;
|
||||
return costs;
|
||||
}
|
||||
|
||||
|
@ -93,6 +93,9 @@ constexpr char kVarHandleOp[] = "VarHandleOp";
|
||||
constexpr char kVarHandlesOp[] = "_VarHandlesOp";
|
||||
constexpr char kReadVariableOp[] = "ReadVariableOp";
|
||||
constexpr char kReadVariablesOp[] = "_ReadVariablesOp";
|
||||
constexpr char kAssignVariableOp[] = "AssignVariableOp";
|
||||
constexpr char kAssignAddVariableOp[] = "AssignAddVariableOp";
|
||||
constexpr char kAssignSubVariableOp[] = "AssignSubVariableOp";
|
||||
|
||||
static const Costs::Duration kMinComputeTime(1);
|
||||
|
||||
@ -375,6 +378,14 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
|
||||
device_cost_impl_.emplace(
|
||||
kFusedBatchNormGrad,
|
||||
wrap(&OpLevelCostEstimator::PredictFusedBatchNormGrad));
|
||||
device_cost_impl_.emplace(
|
||||
kAssignVariableOp, wrap(&OpLevelCostEstimator::PredictAssignVariableOps));
|
||||
device_cost_impl_.emplace(
|
||||
kAssignAddVariableOp,
|
||||
wrap(&OpLevelCostEstimator::PredictAssignVariableOps));
|
||||
device_cost_impl_.emplace(
|
||||
kAssignSubVariableOp,
|
||||
wrap(&OpLevelCostEstimator::PredictAssignVariableOps));
|
||||
|
||||
persistent_ops_ = {
|
||||
kConst, kVariable, kVariableV2, kAutoReloadVariable,
|
||||
@ -435,6 +446,7 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
|
||||
elementwise_ops_.emplace("Tan", EIGEN_COST(scalar_tan_op<float>));
|
||||
// Binary ops alphabetically sorted
|
||||
elementwise_ops_.emplace("Add", EIGEN_COST(scalar_sum_op<float>));
|
||||
elementwise_ops_.emplace("AddV2", EIGEN_COST(scalar_sum_op<float>));
|
||||
elementwise_ops_.emplace("ApproximateEqual", 1);
|
||||
elementwise_ops_.emplace("BiasAdd", EIGEN_COST(scalar_sum_op<float>));
|
||||
elementwise_ops_.emplace("QuantizedBiasAdd",
|
||||
@ -1885,6 +1897,28 @@ Costs OpLevelCostEstimator::PredictMaxPoolGrad(
|
||||
return costs;
|
||||
}
|
||||
|
||||
/* This predict function handles three types of tensorflow ops
|
||||
* AssignVariableOp/AssignAddVariableOp/AssignSubVariableOp, broadcasting
|
||||
* was not possible for these ops, therefore the input tensor's shapes is
|
||||
* enough to compute the cost */
|
||||
Costs OpLevelCostEstimator::PredictAssignVariableOps(
|
||||
const OpContext& op_context) const {
|
||||
bool found_unknown_shapes = false;
|
||||
const auto& op_info = op_context.op_info;
|
||||
/* First input of these ops are reference to the assignee. */
|
||||
if (op_info.inputs_size() != 2) return Costs::ZeroCosts(true);
|
||||
const double total_input_size =
|
||||
CalculateInputSize(op_info, &found_unknown_shapes);
|
||||
const double flops = op_info.op() == kAssignVariableOp
|
||||
? 0.0
|
||||
: CalculateTensorElementCount(op_info.inputs(1),
|
||||
&found_unknown_shapes);
|
||||
Costs costs = PredictOpCountBasedCost(flops, total_input_size, 0, op_info);
|
||||
costs.inaccurate = found_unknown_shapes;
|
||||
costs.num_ops_with_unknown_shapes = found_unknown_shapes;
|
||||
return costs;
|
||||
}
|
||||
|
||||
Costs OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context) const {
|
||||
bool found_unknown_shapes = false;
|
||||
const auto& op_info = op_context.op_info;
|
||||
|
@ -85,6 +85,7 @@ class OpLevelCostEstimator {
|
||||
Costs PredictFusedBatchNorm(const OpContext& op_context) const;
|
||||
Costs PredictFusedBatchNormGrad(const OpContext& op_context) const;
|
||||
Costs PredictEinsum(const OpContext& op_context) const;
|
||||
Costs PredictAssignVariableOps(const OpContext& op_context) const;
|
||||
|
||||
// Generic cost prediction method for fused operations.
|
||||
Costs PredictFusedOp(const OpContext& op_context,
|
||||
|
@ -164,6 +164,10 @@ OpContext DescribeEinsum(const std::vector<int>& dims_a,
|
||||
return op_context;
|
||||
}
|
||||
|
||||
// Describes a placeholder tensor: the proto is deliberately left untouched,
// so its shape and type information stay missing.
void DescribeDummyTensor(OpInfo::TensorProperties* tensor) {}
|
||||
|
||||
// Wrangles the minimum number of proto fields to set up a 1D Tensor for cost
|
||||
// estimation purposes.
|
||||
void DescribeTensor1D(int dim0, OpInfo::TensorProperties* tensor) {
|
||||
@ -1715,5 +1719,34 @@ TEST_F(OpLevelCostEstimatorTest, Einsum) {
|
||||
}
|
||||
}
|
||||
|
||||
// Checks the predicted costs for resource-variable assignment ops on a device
// configured with gigaops=1 and gb_per_sec=1.
TEST_F(OpLevelCostEstimatorTest, PredictResourceVariableOps) {
  TestOpLevelCostEstimator estimator;
  estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/1, /*gb_per_sec=*/1));

  // Builds a context for `op_name` with two inputs: a dummy tensor (no shape
  // or type information) followed by a 1-D tensor with dim0 = 100.
  const auto make_context = [](const char* op_name) {
    OpContext context;
    context.op_info.set_op(op_name);
    DescribeDummyTensor(context.op_info.add_inputs());
    DescribeTensor1D(100, context.op_info.add_inputs());
    return context;
  };

  // Plain assignment: memory traffic only, no compute.
  {
    const OpContext assign_context = make_context("AssignVariableOp");
    const auto predicted = estimator.PredictCosts(assign_context);
    EXPECT_EQ(Costs::Duration(400), predicted.memory_time);
    EXPECT_EQ(Costs::Duration(0), predicted.compute_time);
    EXPECT_EQ(Costs::Duration(400), predicted.execution_time);
    EXPECT_FALSE(predicted.inaccurate);
  }

  // Assign-sub: same memory traffic plus a compute time of 100.
  {
    const OpContext assign_sub_context = make_context("AssignSubVariableOp");
    const auto predicted = estimator.PredictCosts(assign_sub_context);
    EXPECT_EQ(Costs::Duration(400), predicted.memory_time);
    EXPECT_EQ(Costs::Duration(100), predicted.compute_time);
    EXPECT_EQ(Costs::Duration(400), predicted.execution_time);
    EXPECT_FALSE(predicted.inaccurate);
  }
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user