[tf.sparse_segment_*()] Grappler optimization to avoid casts for indices/segment_ids.
The sparse segment reduction ops recently gained kernels that allow DT_INT32 or DT_INT64 values for the "indices" and "segment_ids" inputs. Previously, a lot of libraries included casts from DT_INT64 to DT_INT32 before passing these inputs, which entails an unnecessary copy of the input data. This change adds an arithmetic optimizer stage that removes these casts where possible. PiperOrigin-RevId: 308439398 Change-Id: I2a4ad2d0be294e93f1746f9538f7b0a5a610e46d
This commit is contained in:
parent
ecaa695cb1
commit
12f5bf916b
@ -3571,6 +3571,65 @@ class SimplifyEmbeddingLookupStage : public ArithmeticOptimizerStage {
|
||||
}
|
||||
};
|
||||
|
||||
// Eliminates unnecessary casts before sparse segment reduction operations.
|
||||
//
|
||||
// Existing graphs and library code would often insert a cast from DT_INT64 to
|
||||
// DT_INT32 on the indices and/or segment_ids inputs to "SparseSegment*" ops.
|
||||
// Support for DT_INT64 indices and/or segment_ids now exists, so we can
|
||||
// pass the input directly without a cast.
|
||||
class RemoveCastIntoSegmentReductionStage : public ArithmeticOptimizerStage {
|
||||
public:
|
||||
explicit RemoveCastIntoSegmentReductionStage(
|
||||
const GraphOptimizerContext& ctx,
|
||||
const ArithmeticOptimizerContext& ctx_ext)
|
||||
: ArithmeticOptimizerStage("RemoveCastIntoSegmentReductionStage", ctx,
|
||||
ctx_ext) {}
|
||||
~RemoveCastIntoSegmentReductionStage() override = default;
|
||||
|
||||
bool IsSupported(const NodeDef* node) const override {
|
||||
return IsAnySparseSegmentReduction(*node);
|
||||
}
|
||||
|
||||
Status TrySimplify(NodeDef* reduction_node,
|
||||
string* simplified_node_name) override {
|
||||
if (IsInPreserveSet(*reduction_node)) return Status::OK();
|
||||
|
||||
bool optimized = false;
|
||||
|
||||
// Inputs 1 (indices) and 2 (segment_ids) can be either DT_INT32 or
|
||||
// DT_INT64.
|
||||
std::array<std::pair<int, string>, 2> input_details = {
|
||||
std::make_pair(1, "Tidx"), std::make_pair(2, "Tsegmentids")};
|
||||
|
||||
for (const auto& input : input_details) {
|
||||
int input_index = input.first;
|
||||
const string& type_attr_name = input.second;
|
||||
NodeDef* cast_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetInputNode(reduction_node->input(input_index), &cast_node));
|
||||
DataType original_index_type;
|
||||
if (IsCastFromSupportedType(*cast_node, &original_index_type)) {
|
||||
reduction_node->set_input(input_index, cast_node->input(0));
|
||||
ctx().node_map->UpdateInput(reduction_node->name(),
|
||||
reduction_node->input(1),
|
||||
cast_node->input(0));
|
||||
SetDataTypeToAttr(original_index_type, type_attr_name, reduction_node);
|
||||
optimized = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (optimized) *simplified_node_name = reduction_node->name();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
bool IsCastFromSupportedType(const NodeDef& node, DataType* out_input_type) {
|
||||
if (!IsCast(node)) return false;
|
||||
if (!GetNodeAttr(node, "SrcT", out_input_type).ok()) return false;
|
||||
return *out_input_type == DT_INT32 || *out_input_type == DT_INT64;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
|
||||
@ -3645,6 +3704,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
|
||||
pipeline.AddStage<FuseSquaredDiffStage>(ctx, ctx_ext);
|
||||
if (options_.simplify_embedding_lookup)
|
||||
pipeline.AddStage<SimplifyEmbeddingLookupStage>(ctx, ctx_ext);
|
||||
if (options_.remove_cast_into_segment_reduction)
|
||||
pipeline.AddStage<RemoveCastIntoSegmentReductionStage>(ctx, ctx_ext);
|
||||
|
||||
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
|
||||
<< absl::StrJoin(pipeline.StageNames(), ", ");
|
||||
|
@ -86,6 +86,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
|
||||
bool unary_ops_composition = true;
|
||||
bool remove_stack_slice_same_axis = true;
|
||||
bool simplify_embedding_lookup = true;
|
||||
bool remove_cast_into_segment_reduction = true;
|
||||
|
||||
// Choose which arithmetic optimizer stages will be enabled for a given
|
||||
// optimization level by default.
|
||||
|
@ -4129,5 +4129,49 @@ TEST_F(ArithmeticOptimizerTest, SimplifyEmbeddingLookup) {
|
||||
}
|
||||
}
|
||||
|
||||
// Builds a SparseSegmentSum whose indices and segment_ids inputs are each fed
// through a Cast, for every combination of DT_INT32/DT_INT64 cast targets.
// After running only the RemoveCastIntoSegmentReduction stage, the reduction
// must read directly from the constants, no Cast node may remain, and the
// fetched tensor must match the unoptimized graph's result.
TEST_F(ArithmeticOptimizerTest, RemoveCastIntoSegmentReduction) {
  for (DataType indices_type : {DT_INT32, DT_INT64}) {
    for (DataType segment_ids_type : {DT_INT32, DT_INT64}) {
      tensorflow::Scope s = tensorflow::Scope::NewRootScope();
      Output embeddings = ops::Const(s.WithOpName("embeddings"),
                                     {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
      Output indices =
          ops::Cast(s.WithOpName("cast_indices"),
                    ops::Const(s.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1}),
                    indices_type);
      Output segment_ids = ops::Cast(
          s.WithOpName("cast_segment_ids"),
          ops::Const(s.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2}),
          segment_ids_type);
      Output result = ops::SparseSegmentSum(s.WithOpName("result"), embeddings,
                                            indices, segment_ids);
      Output id = ops::Identity(s.WithOpName("id"), result);

      GrapplerItem item;
      TF_CHECK_OK(s.ToGraphDef(&item.graph));
      item.fetch = {"id"};
      auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
      ASSERT_EQ(tensors_expected.size(), 1);

      GraphDef output;
      ArithmeticOptimizer optimizer;
      EnableOnlyRemoveCastIntoSegmentReduction(&optimizer);
      OptimizeAndPrune(&optimizer, &item, &output);

      // Track whether the reduction node was actually seen so that the input
      // expectations below cannot pass vacuously.
      bool found_result = false;
      for (const auto& node : output.node()) {
        if (node.name() == "result") {
          found_result = true;
          EXPECT_EQ(node.input(1), "indices");
          EXPECT_EQ(node.input(2), "segment_ids");
        }
        EXPECT_NE(node.op(), "Cast");
      }
      EXPECT_TRUE(found_result);

      auto tensors = EvaluateNodes(output, item.fetch);
      ASSERT_EQ(tensors.size(), 1);
      test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
    }
  }
}
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
@ -228,6 +228,12 @@ class ArithmeticOptimizerTest : public GrapplerTest {
|
||||
optimizer->options_.simplify_embedding_lookup = true;
|
||||
}
|
||||
|
||||
// Switches off every stage via DisableAllStages(), then turns the
// remove_cast_into_segment_reduction option back on for `optimizer`.
void EnableOnlyRemoveCastIntoSegmentReduction(
    ArithmeticOptimizer* optimizer) {
  DisableAllStages(optimizer);
  auto& stage_options = optimizer->options_;
  stage_options.remove_cast_into_segment_reduction = true;
}
|
||||
|
||||
private:
|
||||
void DisableAllStages(ArithmeticOptimizer* optimizer) {
|
||||
ArithmeticOptimizer::ArithmeticOptimizerOptions options;
|
||||
@ -256,6 +262,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
|
||||
options.simplify_aggregation = false;
|
||||
options.unary_ops_composition = false;
|
||||
options.simplify_embedding_lookup = false;
|
||||
options.remove_cast_into_segment_reduction = false;
|
||||
optimizer->options_ = options;
|
||||
}
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user