STT-tensorflow/tensorflow/compiler/tf2xla/lib/scatter.cc
Mark Heffernan 1ed59e52b1 Replace calls to ShapeUtil::Rank with Shape::rank.
No functional change. ShapeUtil::Rank is marked as deprecated. A later CL will remove it.

PiperOrigin-RevId: 226412117
2018-12-20 16:38:08 -08:00

207 lines
7.8 KiB
C++

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/lib/scatter.h"
#include <memory>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {

// Emits an XLA Scatter that combines `updates` into `buffer` at the positions
// named by `indices`, and returns the resulting op.
//
//   buffer:              the operand being scattered into.
//   updates:             values combined into `buffer`. If it is rank 0 while
//                        the expected updates shape is not, it is broadcast to
//                        that expected shape first.
//   indices:             when `indices_are_vectors` is true, the minor
//                        dimension of `indices` holds index vectors that
//                        address the leading `indices.shape[-1]` dimensions of
//                        `buffer`; otherwise each element is a scalar index
//                        into dimension 0 of `buffer`.
//   indices_are_vectors: selects between the two index layouts above.
//   combiner:            builds the scalar (current, update) -> result
//                        computation. If empty, no ops are added after the two
//                        parameters, so the combiner's root is the last-added
//                        op (parameter 1, the update value).
//   builder:             the builder the scatter is emitted into.
//
// Returns `buffer` unchanged when there are no indices to apply. Returns
// InvalidArgument when the indexed dimensions do not exist in `buffer` or have
// size zero, and propagates any error from building the combiner.
xla::StatusOr<xla::XlaOp> XlaScatter(
    const xla::XlaOp& buffer, const xla::XlaOp& updates,
    const xla::XlaOp& indices, bool indices_are_vectors,
    const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>&
        combiner,
    xla::XlaBuilder* builder) {
  TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer));
  TF_ASSIGN_OR_RETURN(xla::Shape updates_shape, builder->GetShape(updates));
  TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices));
  absl::Span<const int64> indices_dims =
      xla::AsInt64Slice(indices_shape.dimensions());

  // If the indices are N-dimensional, the minor dimension of indices contains
  // the indices to update. Otherwise the indices are all scalars.
  int64 num_index_dims = 1;
  if (indices_are_vectors) {
    TF_RET_CHECK(!indices_dims.empty());
    num_index_dims = indices_dims.back();
    if (num_index_dims > buffer_shape.rank()) {
      return errors::InvalidArgument(
          "The size of the minor dimension of the indices (shape: ",
          xla::ShapeUtil::HumanString(indices_shape),
          ") must be <= the rank of the buffer (shape: ",
          xla::ShapeUtil::HumanString(buffer_shape), ")");
    }
    // The remaining (batch) dimensions enumerate the individual updates.
    indices_dims.remove_suffix(1);
  }

  int64 num_indices = 1;
  for (int64 dim : indices_dims) {
    num_indices *= dim;
  }

  // Degenerate case: nothing to update. Return the buffer unchanged.
  if (num_indices == 0) {
    return buffer;
  }

  // Scalar indices address dimension 0 of the buffer, so the buffer must have
  // at least one dimension. (The vector case was validated above.) Without
  // this check the size-zero loop below would read a nonexistent dimension.
  if (!indices_are_vectors && buffer_shape.rank() < 1) {
    return errors::InvalidArgument(
        "Scalar indices (shape: ", xla::ShapeUtil::HumanString(indices_shape),
        ") require a buffer of rank >= 1 (shape: ",
        xla::ShapeUtil::HumanString(buffer_shape), ")");
  }

  // If any of the indexed dimensions are zero in the buffer, the update cannot
  // succeed since it updates a slice of size 1.
  for (int64 i = 0; i < num_index_dims; ++i) {
    if (xla::ShapeUtil::GetDimension(buffer_shape, i) == 0) {
      return errors::InvalidArgument("Scatter dimension ", i,
                                     " is of size zero in tensor with shape ",
                                     xla::ShapeUtil::HumanString(buffer_shape));
    }
  }

  // Example of a 1-D scatter that updates two [3,1] tensors in a tensor of
  // shape [3,3]:
  // NOTE: ***This case will not be generated by any of the tf.scatter ops.***
  //
  //   operand = s32[3,3] parameter(0)
  //   indices = s32[2] parameter(1)
  //   updates = s32[3,2] parameter(2)
  //   scatter = s32[3,3] scatter(operand, indices, updates),
  //       to_apply=update_computation,
  //       update_window_dims={0},
  //       inserted_window_dims={1},
  //       scatter_dims_to_operand_dims={1},
  //       index_vector_dim=1
  //
  //
  // Example of a 1-D scatter that updates two [1,3] tensors in a tensor of
  // shape [3,3]:
  //
  //   operand = s32[3,3] parameter(0)
  //   indices = s32[2] parameter(1)
  //   updates = s32[2,3] parameter(2)
  //   scatter = s32[3,3] scatter(operand, indices, updates),
  //       to_apply=update_computation,
  //       update_window_dims={1},
  //       inserted_window_dims={0},
  //       scatter_dims_to_operand_dims={0},
  //       index_vector_dim=1
  //
  //
  // Example of an N-D scatter updating slices of shape [1,1,2] in a tensor of
  // shape [3,3,2]
  //
  //   operand = s32[3,3,2] parameter(0)
  //   indices = s32[2,2] parameter(1)
  //   updates = s32[2,2] parameter(2)
  //   scatter = s32[3,3,2] scatter(operand, indices, updates),
  //       to_apply=update_computation,
  //       update_window_dims={1},
  //       inserted_window_dims={0,1},
  //       scatter_dims_to_operand_dims={0,1},
  //       index_vector_dim=1
  //
  //
  // Example of a scatter updating slices of shape [] in a tensor of shape
  // [1,1], using a single length-2 index vector:
  //
  //   operand = s32[1,1] parameter(0)
  //   indices = s32[1,2] parameter(1)
  //   updates = s32[1] parameter(2)
  //   scatter = s32[1,1] scatter(operand, indices, updates),
  //       to_apply=update_computation,
  //       update_window_dims={},
  //       inserted_window_dims={0,1},
  //       scatter_dims_to_operand_dims={0,1},
  //       index_vector_dim=1
  // Note that a scalar updates operand would be broadcast to [1] in this case.
  //
  xla::ScatterDimensionNumbers dim_numbers;
  // With vector indices the index vectors live in the minor dimension; with
  // scalar indices the index-vector dimension is one past the last dimension.
  dim_numbers.set_index_vector_dim(indices_are_vectors
                                       ? indices_shape.dimensions_size() - 1
                                       : indices_shape.dimensions_size());

  int64 updates_rank = updates_shape.rank();
  int64 buffer_rank = buffer_shape.rank();
  int64 num_window_dims_in_updates = buffer_rank - num_index_dims;

  // If the rank of `updates` is 0 and does not match the expected rank of
  // updates, broadcast `updates` to the expected shape of updates. The
  // expected shape is the index batch dimensions followed by the buffer
  // dimensions that are not indexed.
  auto new_updates = updates;
  std::vector<int64> expected_updates_dims(indices_dims.begin(),
                                           indices_dims.end());
  for (int64 dim = num_index_dims; dim < buffer_rank; ++dim) {
    expected_updates_dims.push_back(buffer_shape.dimensions(dim));
  }
  int64 expected_updates_rank = expected_updates_dims.size();
  if (updates_rank == 0 && expected_updates_rank != 0) {
    new_updates = xla::Broadcast(updates, expected_updates_dims);
    TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates));
    updates_rank = updates_shape.rank();
  }

  // The trailing dimensions of `updates` are the window (slice) dimensions.
  if (updates_rank > 0) {
    for (int64 i = (updates_rank - num_window_dims_in_updates);
         i < updates_rank; ++i) {
      dim_numbers.add_update_window_dims(i);
    }
  }

  // The leading `num_index_dims` buffer dimensions are addressed by the
  // indices and have window size 1, so they are inserted window dims.
  for (int64 i = 0; i < num_index_dims; ++i) {
    dim_numbers.add_inserted_window_dims(i);
    dim_numbers.add_scatter_dims_to_operand_dims(i);
  }

  // Build the combiner computation. Errors recorded on the sub-builder (for
  // example from a malformed `combiner`) are returned to the caller rather
  // than aborting the process.
  xla::XlaComputation combiner_computation;
  {
    xla::XlaBuilder cb("scatter-combiner");
    auto xla_scalar_shape =
        xla::ShapeUtil::MakeShape(buffer_shape.element_type(), {});
    auto p0 = xla::Parameter(&cb, 0, xla_scalar_shape, "p0");
    auto p1 = xla::Parameter(&cb, 1, xla_scalar_shape, "p1");
    if (combiner) {
      combiner(p0, p1, &cb);
    }
    TF_ASSIGN_OR_RETURN(combiner_computation, cb.Build());
  }

  VLOG(3) << "Scatter op:";
  VLOG(3) << "  Input: " << xla::ShapeUtil::HumanString(buffer_shape);
  VLOG(3) << "  Indices: " << xla::ShapeUtil::HumanString(indices_shape);
  VLOG(3) << "  Updates: " << xla::ShapeUtil::HumanString(updates_shape);
  VLOG(3) << "  Scatter Dimension Numbers: ";
  VLOG(3) << "    index_vector_dim: " << dim_numbers.index_vector_dim();
  VLOG(3) << "    update_window_dims: ["
          << absl::StrJoin(dim_numbers.update_window_dims(), ",") << "]";
  VLOG(3) << "    inserted_window_dims: ["
          << absl::StrJoin(dim_numbers.inserted_window_dims(), ",") << "]";
  VLOG(3) << "    scatter_dims_to_operand_dims: ["
          << absl::StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ",")
          << "]";

  return xla::Scatter(buffer, indices, new_updates, combiner_computation,
                      dim_numbers);
}

}  // namespace tensorflow