END_PUBLIC Note: this CL will break builds. cl/159887762 to follow to fix all the breakages. --- Commit2336cdf7f
authored by Maxwell Paul Brickner<mbrickn@users.noreply.github.com> Committed by gunan<gunan@google.com>: Updated link to use HTTPS (#10998) Howdy! I just updated a link to use https instead of http. Thanks! --- Commitad0892df1
authored by Luke Iwanski<luke@codeplay.com> Committed by Luke Iwanski<luke@codeplay.com>: [OpenCL] Fixes run_metadata_test for SYCL This test is designed to test CUDA specific behavior --- Commit6b37a0725
authored by Todd Wang<toddwang@gmail.com> Committed by GitHub<noreply@github.com>: Update comments --- Commit1699d904a
authored by John Lawson<john@codeplay.com> Committed by Luke Iwanski<luke@codeplay.com>: [OpenCL] Fixes CUDA specific test run on SYCL (#56) The testBadParentValuesOnGPU should only be run on CUDA devices, as the test checks for particular CUDA behaviour. We don't actually provide a SYCL kernel for GatherTree and so it's not a problem that the tests don't target SYCL. --- Commit3c1946230
authored by myPrecious<Moriadry@users.noreply.github.com> Committed by Shanqing Cai<cais@google.com>: Java API to get the size of specified input list of operations. (#10865) * Java API to get the size of specified input list of operations * remove unnecessary explain to avoid bring a new term to users. --- Commite911c7480
authored by Luke Iwanski<luke@codeplay.com> Committed by Luke Iwanski<luke@codeplay.com>: [OpenCL] REGISTER -> REGISTER6 --- Commitfbf6c4cec
authored by superryanguo<superryanguo@gmail.com> Committed by superryanguo<superryanguo@gmail.com>: Simplify the Quickstart section with the weblink is better --- Commit72e2918cc
authored by Taehoon Lee<taehoonlee@snu.ac.kr> Committed by Taehoon Lee<taehoonlee@snu.ac.kr>: Fix typos --- Commit90c4406b7
authored by Rishabh Patel<patelrishabh@users.noreply.github.com> Committed by GitHub<noreply@github.com>: Correct the learning rate as per the code snippet --- Commit03da61134
authored by Todd Wang<toddwang@gmail.com> Committed by GitHub<noreply@github.com>: Update ir_array.cc --- Commit2df6cd3ac
authored by Todd Wang<toddwang@gmail.com> Committed by GitHub<noreply@github.com>: Another try --- Commitaf0cbace1
authored by Luke Iwanski<luke@codeplay.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: [OpenCL] Transpose to go through Eigen (#10321) --- Commitfc7361081
authored by Luke Iwanski<luke@codeplay.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: [OpenCL] Registers RGBToHSV and HSVToRGB (#91) (#10848) * [OpenCL] Added RGBToHSV and HSVToRGB * Aligning '\' --- Commit832894ef8
authored by Luke Iwanski<luke@codeplay.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: [OpenCL] Registers AdjustContrastv2 (#10949) * [OpenCL] Registers AdjustContrastv2 (#93) * [OpenCL] Extended adjust_contrast_op_benchmark_test for OpenCL (#96) * [OpenCL] Extended adjust_contrast_op_benchmark_test for OpenCL * simplified to #ifndef * Changed to "#if GOOGLE_CUDA" * Update adjust_contrast_op_benchmark_test.cc * Added comments --- Commitcb4c2f8d1
authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Make TransferBufferToInFeed not virual so it compiles. --- Commite89f04d80
authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Fix calling Literal member functions. --- Commit15a8df724
authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Fix mac build clone from meheff's change: [XLA] Change return type of DeviceAssignment::Deserialize to fix build breakage on mac. The mac build had the following error: error: incomplete type 'xla::DeviceAssignment' used in type trait expression This was due to a static method returning a StatusOr<DeviceAssignment> inside of the definition of DeviceAssignment. --- Commita54d43fa4
authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Replace LiteralUtil to Literal in compiler/plugin/executor --- Commit88a6bb80c
authored by Guenther Schmuelling<guschmue@microsoft.com> Committed by Guenther Schmuelling<guschmue@microsoft.com>: expand inline for debug builds to limit number of symbols --- Commit62fb49d31
authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Fix visibility error for contrib/remote_fused_graph/pylib/BUILD. --- Commit4c75252f2
authored by Mark Neumann<markn@allenai.org> Committed by Mark Neumann<markn@allenai.org>: fix initial test values to avoid numerical instability --- Commitb58d98353
authored by sj6077<epik03sj@gmail.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: Fixes of AutoParallel bug (#10368) * Fix the bug that auto_parallel could replicate variable snapshot name * Use NodeName in grappler:utils instead of substr, convert variables->variable_def of grappler item * remove variable_def from grappler item, exclude snapshot nodes from dont_replicate_nodes in auto_parallel --- Commita286b7db8
authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Make debug_test slice integer. --- Commit97fcfdfa6
authored by Toby Boyd<tobyboyd@google.com> Committed by GitHub<noreply@github.com>: Fixed path to seq2seq.py and minor formatting --- Commit63c1befb8
authored by Anish Shah<shah.anish07@gmail.com> Committed by Anish Shah<shah.anish07@gmail.com>: Improve docs for tf.nn.depthwise_conv2d_native --- Commit8d42202b2
authored by Yong Tang<yong.tang.github@outlook.com> Committed by Yong Tang<yong.tang.github@outlook.com>: Fix mismatched delete in mkl_tfconv_op.cc This fix fixes mismatched new[]-delete in mkl_tfconv_op.cc (the file went through clang-format so there are some additional changes) Signed-off-by: Yong Tang <yong.tang.github@outlook.com> --- Commit26301bd55
authored by Danny Goodman<goodman.danny@gmail.com> Committed by Danny Goodman<goodman.danny@gmail.com>: fix error format --- Commitb3f33ad46
authored by Yao Zhang<yaozhang@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Make changes to prepare for the fused option of batch norm to be set to None (None means using fused batch norm if possible). PiperOrigin-RevId: 159649743 --- Commita4a469832
authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [XLA] Add tests for select ops and while loops that produce tuples that contain predicates. PiperOrigin-RevId: 159645900 --- Commit980d3f2be
authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Use C API to implement Operation.name property This name property is used in many existing tests including those that already run with C API enabled (math_ops_test, framework_ops_test, session_test, session_partial_run_test, math_ops_test_gpu, etc). PiperOrigin-RevId: 159645767 --- Commit26239c706
authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Previously we didn't have an implementation of BatchNormInference and BatchNormTraining, which gives a linker error if anyone ever tries to call that. A dummy implementation is friendlier than a linker error. PiperOrigin-RevId: 159645612 --- Commitf671c5caa
authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: BEGIN_PUBLIC Automated g4 rollback of changelist 159570549 PiperOrigin-RevId: 160182040
327 lines
13 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/core/util/strided_slice_op.h"
|
|
#include "tensorflow/compiler/tf2xla/literal_util.h"
|
|
#include "tensorflow/compiler/tf2xla/type_util.h"
|
|
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
|
#include "tensorflow/core/framework/op_kernel.h"
|
|
#include "tensorflow/core/framework/register_types.h"
|
|
#include "tensorflow/core/framework/tensor.h"
|
|
#include "tensorflow/core/kernels/ops_util.h"
|
|
#include "tensorflow/core/lib/core/status.h"
|
|
#include "tensorflow/core/lib/gtl/array_slice.h"
|
|
#include "tensorflow/core/platform/mem.h"
|
|
|
|
namespace tensorflow {
|
|
namespace {
|
|
|
|
class StridedSliceOp : public XlaOpKernel {
|
|
public:
|
|
explicit StridedSliceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
|
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
|
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
|
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
|
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
|
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
|
|
}
|
|
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
const TensorShape input_shape = ctx->InputShape(0);
|
|
|
|
TensorShape final_shape;
|
|
gtl::InlinedVector<int64, 4> begin;
|
|
gtl::InlinedVector<int64, 4> end;
|
|
gtl::InlinedVector<int64, 4> strides;
|
|
|
|
xla::Literal begin_literal, end_literal, strides_literal;
|
|
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
|
|
OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
|
|
OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));
|
|
|
|
Tensor begin_tensor, end_tensor, strides_tensor;
|
|
OP_REQUIRES_OK(
|
|
ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
|
|
OP_REQUIRES_OK(ctx,
|
|
LiteralToHostTensor(end_literal, index_type_, &end_tensor));
|
|
OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
|
|
&strides_tensor));
|
|
|
|
TensorShape dummy_processing_shape;
|
|
bool dummy = false;
|
|
OP_REQUIRES_OK(ctx,
|
|
ValidateStridedSliceOp(
|
|
&begin_tensor, &end_tensor, strides_tensor, input_shape,
|
|
begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
|
|
shrink_axis_mask_, &dummy_processing_shape, &final_shape,
|
|
&dummy, &dummy, &dummy, &begin, &end, &strides));
|
|
|
|
gtl::InlinedVector<int64, 4> dimensions_to_reverse;
|
|
gtl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
|
|
|
|
for (int i = 0; i < begin.size(); ++i) {
|
|
if (strides[i] > 0) {
|
|
slice_begin.push_back(begin[i]);
|
|
slice_end.push_back(end[i]);
|
|
slice_strides.push_back(strides[i]);
|
|
} else {
|
|
// Negative stride: swap begin and end, add 1 because the interval
|
|
// is semi-open, and mark the dimension to be reversed.
|
|
slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1);
|
|
slice_end.push_back(input_shape.dim_size(i) - end[i] - 1);
|
|
slice_strides.push_back(-strides[i]);
|
|
dimensions_to_reverse.push_back(i);
|
|
}
|
|
}
|
|
|
|
xla::ComputationDataHandle slice = ctx->Input(0);
|
|
if (!dimensions_to_reverse.empty()) {
|
|
slice = ctx->builder()->Rev(slice, dimensions_to_reverse);
|
|
}
|
|
|
|
slice = ctx->builder()->Slice(slice, slice_begin, slice_end, slice_strides);
|
|
|
|
slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes());
|
|
ctx->SetOutput(0, slice);
|
|
}
|
|
|
|
private:
|
|
int32 begin_mask_, end_mask_;
|
|
int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
|
|
DataType index_type_;
|
|
};
|
|
|
|
REGISTER_XLA_OP(Name("StridedSlice"), StridedSliceOp);
|
|
|
|
// Gradient of StridedSlice: scatters the incoming gradient `dy` back into a
// zero tensor of the original input's shape by reshaping, reversing the
// negatively-strided dimensions, and padding (edge + interior) so each dy
// element lands on the input position it was sliced from.
class StridedSliceGradOp : public XlaOpKernel {
 public:
  explicit StridedSliceGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    // Masks and index dtype mirror the forward StridedSlice op's attributes.
    OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape processing_shape, final_shape;
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> end;
    gtl::InlinedVector<int64, 4> strides;

    // Input 0 carries the shape of the original (pre-slice) input; it must
    // be a compile-time constant.
    TensorShape input_shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));

    // begin/end/strides must also be compile-time constants.
    xla::Literal begin_literal, end_literal, strides_literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));

    Tensor begin_tensor, end_tensor, strides_tensor;
    OP_REQUIRES_OK(
        ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
    OP_REQUIRES_OK(ctx,
                   LiteralToHostTensor(end_literal, index_type_, &end_tensor));
    OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
                                            &strides_tensor));

    // Resolve the masks into dense per-dimension begin/end/strides and
    // recover both the processing shape (same rank as the input) and the
    // final (post shrink/new-axis) shape of the forward slice.
    bool dummy = false;
    OP_REQUIRES_OK(
        ctx, ValidateStridedSliceOp(
                 &begin_tensor, &end_tensor, strides_tensor, input_shape,
                 begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
                 shrink_axis_mask_, &processing_shape, &final_shape, &dummy,
                 &dummy, &dummy, &begin, &end, &strides));

    // Check to make sure dy is consistent with the original slice
    const TensorShape dy_shape = ctx->InputShape(4);
    OP_REQUIRES(
        ctx, final_shape == dy_shape,
        errors::InvalidArgument("shape of dy was ", dy_shape.DebugString(),
                                " instead of ", final_shape.DebugString()));

    OP_REQUIRES(
        ctx, input_shape.dims() == processing_shape.dims(),
        errors::Internal(
            "input shape and processing shape must have same number of dims"));

    // Padding value for positions the slice never touched; their gradient
    // is zero.
    auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0));

    xla::ComputationDataHandle grad = ctx->Input(4);

    // Undo any new/shrink axes.
    grad = ctx->builder()->Reshape(grad, processing_shape.dim_sizes());

    // Pad the input gradients.
    gtl::InlinedVector<int64, 4> dimensions_to_reverse;
    xla::PaddingConfig padding_config;

    for (int i = 0; i < processing_shape.dims(); ++i) {
      auto* dims = padding_config.add_dimensions();
      if (strides[i] > 0) {
        // Forward stride: low padding places the first dy element at
        // begin[i]; interior padding of (stride - 1) spaces out the rest.
        dims->set_edge_padding_low(begin[i]);
        dims->set_interior_padding(strides[i] - 1);

        // Pad the upper dimension up to the expected input shape. (It's
        // not sufficient simply to use "end[i]" to compute the padding in
        // cases where the stride does not divide evenly into the interval
        // between begin[i] and end[i].)
        int64 size =
            dims->edge_padding_low() + processing_shape.dim_size(i) +
            (processing_shape.dim_size(i) - 1) * dims->interior_padding();
        dims->set_edge_padding_high(input_shape.dim_size(i) - size);
      } else {
        // Negative stride: the gradient is reversed along this dimension,
        // so the roles of the low/high edge padding swap; begin[i] anchors
        // the high end (interval is semi-open, hence the -1).
        dimensions_to_reverse.push_back(i);
        dims->set_edge_padding_high(input_shape.dim_size(i) - begin[i] - 1);
        dims->set_interior_padding(-strides[i] - 1);

        // Pad the lower dimension up to the expected input shape.
        int64 size =
            dims->edge_padding_high() + processing_shape.dim_size(i) +
            (processing_shape.dim_size(i) - 1) * dims->interior_padding();
        dims->set_edge_padding_low(input_shape.dim_size(i) - size);
      }
    }
    if (!dimensions_to_reverse.empty()) {
      grad = ctx->builder()->Rev(grad, dimensions_to_reverse);
    }
    grad = ctx->builder()->Pad(grad, zero, padding_config);
    ctx->SetOutput(0, grad);
  }

 private:
  int32 begin_mask_, end_mask_;
  int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
  DataType index_type_;
};

REGISTER_XLA_OP(Name("StridedSliceGrad"), StridedSliceGradOp);
|
|
|
|
class StridedSliceAssignOp : public XlaOpKernel {
|
|
public:
|
|
explicit StridedSliceAssignOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
|
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
|
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
|
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
|
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
|
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
|
|
}
|
|
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
TensorShape final_shape;
|
|
gtl::InlinedVector<int64, 4> begin;
|
|
gtl::InlinedVector<int64, 4> end;
|
|
gtl::InlinedVector<int64, 4> strides;
|
|
|
|
xla::Literal begin_literal, end_literal, strides_literal;
|
|
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
|
|
OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
|
|
OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));
|
|
|
|
Tensor begin_tensor, end_tensor, strides_tensor;
|
|
OP_REQUIRES_OK(
|
|
ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
|
|
OP_REQUIRES_OK(ctx,
|
|
LiteralToHostTensor(end_literal, index_type_, &end_tensor));
|
|
OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
|
|
&strides_tensor));
|
|
|
|
DataType lhs_type;
|
|
TensorShape lhs_shape;
|
|
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &lhs_type, &lhs_shape));
|
|
|
|
const TensorShape rhs_shape = ctx->InputShape(4);
|
|
|
|
TensorShape dummy_processing_shape;
|
|
bool dummy = false;
|
|
OP_REQUIRES_OK(ctx,
|
|
ValidateStridedSliceOp(
|
|
&begin_tensor, &end_tensor, strides_tensor, lhs_shape,
|
|
begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
|
|
shrink_axis_mask_, &dummy_processing_shape, &final_shape,
|
|
&dummy, &dummy, &dummy, &begin, &end, &strides));
|
|
|
|
if (final_shape.num_elements() == 0 && rhs_shape.num_elements() == 0) {
|
|
// DynamicUpdateSlice does not allow 0-element updates. We should probably
|
|
// check that rhs_shape can be broadcast to final_shape, but that is
|
|
// probably better handled when implementing broadcasting more generally.
|
|
return;
|
|
}
|
|
|
|
// TODO(aselle): This check is too strong, we only should need
|
|
// input_shape to be broadcastable to final_shape
|
|
OP_REQUIRES(ctx, final_shape == rhs_shape,
|
|
errors::Unimplemented(
|
|
"sliced l-value shape ", final_shape.DebugString(),
|
|
" does not match r-value shape ", rhs_shape.DebugString(),
|
|
". Automatic broadcasting not yet implemented."));
|
|
|
|
xla::ComputationDataHandle lhs;
|
|
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &lhs));
|
|
|
|
xla::ComputationDataHandle rhs = ctx->Input(4);
|
|
|
|
gtl::InlinedVector<int64, 4> dimensions_to_reverse;
|
|
gtl::InlinedVector<int64, 4> slice_begin, slice_dims;
|
|
for (int i = 0; i < begin.size(); ++i) {
|
|
// TODO(phawkins): implement strides != 1
|
|
OP_REQUIRES(
|
|
ctx, strides[i] == 1 || strides[i] == -1,
|
|
errors::Unimplemented("Strides != 1 or -1 are not yet implemented"));
|
|
if (strides[i] > 0) {
|
|
slice_begin.push_back(begin[i]);
|
|
slice_dims.push_back(end[i] - begin[i]);
|
|
} else {
|
|
// Negative stride: swap begin and end, add 1 because the interval
|
|
// is semi-open, and mark the dimension to be reversed.
|
|
slice_begin.push_back(end[i] + 1);
|
|
slice_dims.push_back(begin[i] - end[i]);
|
|
dimensions_to_reverse.push_back(i);
|
|
}
|
|
}
|
|
|
|
if (!dimensions_to_reverse.empty()) {
|
|
rhs = ctx->builder()->Rev(rhs, dimensions_to_reverse);
|
|
}
|
|
rhs = ctx->builder()->Reshape(rhs, slice_dims);
|
|
|
|
if (lhs_shape.dims() == 0) {
|
|
// TODO(b/38323843): DynamicUpdateSlice crashes on rank 0 inputs. Fix
|
|
// and remove this workaround.
|
|
lhs = rhs;
|
|
} else {
|
|
lhs = ctx->builder()->DynamicUpdateSlice(
|
|
lhs, rhs, ctx->builder()->ConstantR1<int64>(slice_begin));
|
|
}
|
|
|
|
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, lhs_type, lhs));
|
|
}
|
|
|
|
private:
|
|
int32 begin_mask_, end_mask_;
|
|
int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
|
|
DataType index_type_;
|
|
};
|
|
|
|
REGISTER_XLA_OP(Name("ResourceStridedSliceAssign"), StridedSliceAssignOp);
|
|
|
|
} // namespace
|
|
} // namespace tensorflow
|