diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index 45048bd6efb..f6d6ee0710a 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -248,21 +248,36 @@ void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) { } void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) { + auto tf_dlm_context = GetDlContext(h, status); + if (!status->status.ok()) { + return nullptr; + } + + auto* tf_dlm_data = TFE_TensorHandleDevicePointer(h, status); + if (!status->status.ok()) { + return nullptr; + } + const Tensor* tensor = GetTensorFromHandle(h, status); TF_DataType data_type = static_cast(tensor->dtype()); - TensorReference tensor_ref(*tensor); // This will call buf_->Ref() + auto tf_dlm_type = GetDlDataType(data_type, status); + if (!status->status.ok()) { + return nullptr; + } + + TensorReference tensor_ref(*tensor); // This will call buf_->Ref() auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref); tf_dlm_tensor_ctx->reference = tensor_ref; DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor; dlm_tensor->manager_ctx = tf_dlm_tensor_ctx; dlm_tensor->deleter = &DLManagedTensorDeleter; - dlm_tensor->dl_tensor.ctx = GetDlContext(h, status); + dlm_tensor->dl_tensor.ctx = tf_dlm_context; int ndim = tensor->dims(); dlm_tensor->dl_tensor.ndim = ndim; - dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status); - dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status); + dlm_tensor->dl_tensor.data = tf_dlm_data; + dlm_tensor->dl_tensor.dtype = tf_dlm_type; std::vector* shape_arr = &tf_dlm_tensor_ctx->shape; std::vector* stride_arr = &tf_dlm_tensor_ctx->strides; @@ -275,13 +290,14 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) { (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1]; } - dlm_tensor->dl_tensor.shape = &(*shape_arr)[0]; + dlm_tensor->dl_tensor.shape = shape_arr->data(); // There are two ways to represent compact row-major data // 1) nullptr indicates tensor is compact and row-majored. // 2) fill in the strides array as the real case for compact row-major data. // Here we choose option 2, since some frameworks didn't handle the strides // argument properly. - dlm_tensor->dl_tensor.strides = &(*stride_arr)[0]; + dlm_tensor->dl_tensor.strides = stride_arr->data(); + dlm_tensor->dl_tensor.byte_offset = 0; // TF doesn't handle the strides and byte_offsets here return static_cast(dlm_tensor); diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index f9c720a2ba2..1ecc0ab7a50 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/loader_util.h" #include "tensorflow/cc/saved_model/reader.h" #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.proto.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/io/path.h" @@ -72,26 +73,41 @@ uint64 GetLatencyMicroseconds(const uint64 start_microseconds) { // Ensure that constant tensors loaded from the saved model have valid shape. // Also ensure that constant nodes have a value assigned to them. // TODO(b/154763635): this is temporary and will be replaced with a better audit +static Status ValidateNode(const NodeDef& node) { + const auto node_iterator = node.attr().find("value"); + if (node_iterator != node.attr().end()) { + AttrValue node_value = node_iterator->second; + if (node_value.has_tensor()) { + const PartialTensorShape node_shape(node_value.tensor().tensor_shape()); + if (node_shape.num_elements() < 0) { + return errors::FailedPrecondition( + "Saved model contains node \"", node.name(), "\" (op \"", node.op(), + "\") which initializes from a tensor with ", + node_shape.num_elements(), " elements"); + } + } + } else if (node.op() == "Const") { + return errors::FailedPrecondition( + "Saved model contains node \"", node.name(), + "\" which is a constant tensor but no value has been provided"); + } + return Status::OK(); +} + static Status ValidateSavedTensors(const GraphDef& graph_def) { for (const auto& node : graph_def.node()) { - const auto node_iterator = node.attr().find("value"); - if (node_iterator != node.attr().end()) { - AttrValue node_value = node_iterator->second; - if (node_value.has_tensor()) { - const PartialTensorShape node_shape(node_value.tensor().tensor_shape()); - if (node_shape.num_elements() < 0) { - return errors::FailedPrecondition( - "Saved model contains node \"", node.name(), "\" (op \"", - node.op(), "\") which initializes from a tensor with ", - node_shape.num_elements(), " elements"); - } + TF_RETURN_IF_ERROR(ValidateNode(node)); + } + + if (graph_def.has_library()) { + const FunctionDefLibrary& library = graph_def.library(); + for (const auto& function : library.function()) { + for (const auto& node : function.node_def()) { + TF_RETURN_IF_ERROR(ValidateNode(node)); } - } else if (node.op() == "Const") { - return errors::FailedPrecondition( - "Saved model contains node \"", node.name(), - "\" which is a constant tensor but no value has been provided"); } } + return Status::OK(); } diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index 1a56cc30510..980a75bf254 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -307,7 +307,12 @@ Status KernelAndDeviceOp::Run( if (outputs != nullptr) { outputs->clear(); for (int i = 0; i < context.num_outputs(); ++i) { - outputs->push_back(Tensor(*context.mutable_output(i))); + const auto* output_tensor = context.mutable_output(i); + if (output_tensor != nullptr) { + outputs->push_back(Tensor(*output_tensor)); + } else { + outputs->push_back(Tensor()); + } } } return Status::OK(); diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 7da864a6027..14f7d99bf2e 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -6085,6 +6085,24 @@ tf_kernel_library( deps = STRING_DEPS, ) +tf_cc_test( + name = "as_string_op_test", + size = "small", + srcs = ["as_string_op_test.cc"], + deps = [ + ":as_string_op", + ":ops_testutil", + ":ops_util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_kernel_library( name = "unicode_ops", prefix = "unicode_ops", diff --git a/tensorflow/core/kernels/as_string_op.cc b/tensorflow/core/kernels/as_string_op.cc index 8341909fbc8..b9af976a654 100644 --- a/tensorflow/core/kernels/as_string_op.cc +++ b/tensorflow/core/kernels/as_string_op.cc @@ -65,9 +65,26 @@ class AsStringOp : public OpKernel { OP_REQUIRES(ctx, !(scientific && shortest), errors::InvalidArgument( "Cannot select both scientific and shortest notation")); + format_ = "%"; + if (!fill_string.empty()) { + switch (fill_string[0]) { + case ' ': + case '+': + case '-': + case '0': + case '#': + strings::Appendf(&format_, "%s", fill_string.c_str()); + break; + default: + bool fill_not_supported = true; + OP_REQUIRES(ctx, !fill_not_supported, + errors::InvalidArgument("Fill argument not supported: \"", + fill_string, "\"")); + } + } if (width > -1) { - strings::Appendf(&format_, "%s%d", fill_string.c_str(), width); + strings::Appendf(&format_, "%d", width); } if (precision > -1) { strings::Appendf(&format_, ".%d", precision); diff --git a/tensorflow/core/kernels/as_string_op_test.cc b/tensorflow/core/kernels/as_string_op_test.cc new file mode 100644 index 00000000000..dff78e25e72 --- /dev/null +++ b/tensorflow/core/kernels/as_string_op_test.cc @@ -0,0 +1,245 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +class AsStringGraphTest : public OpsTestBase { + protected: + Status Init(DataType input_type, const string& fill = "", int width = -1, + int precision = -1, bool scientific = false, + bool shortest = false) { + TF_CHECK_OK(NodeDefBuilder("op", "AsString") + .Input(FakeInput(input_type)) + .Attr("fill", fill) + .Attr("precision", precision) + .Attr("scientific", scientific) + .Attr("shortest", shortest) + .Attr("width", width) + .Finalize(node_def())); + return InitOp(); + } +}; + +TEST_F(AsStringGraphTest, Int8) { + TF_ASSERT_OK(Init(DT_INT8)); + + AddInputFromArray(TensorShape({3}), {-42, 0, 42}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_STRING, TensorShape({3})); + test::FillValues(&expected, {"-42", "0", "42"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(AsStringGraphTest, Int64) { + TF_ASSERT_OK(Init(DT_INT64)); + + AddInputFromArray(TensorShape({3}), {-42, 0, 42}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_STRING, TensorShape({3})); + test::FillValues(&expected, {"-42", "0", "42"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(AsStringGraphTest, FloatDefault) { + TF_ASSERT_OK(Init(DT_FLOAT)); + + AddInputFromArray(TensorShape({4}), {-42, 0, 3.14159, 42}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_STRING, TensorShape({4})); + test::FillValues( + &expected, {"-42.000000", "0.000000", "3.141590", "42.000000"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(AsStringGraphTest, FloatScientific) { + TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/-1, + /*scientific=*/true)); + + AddInputFromArray(TensorShape({4}), {-42, 0, 3.14159, 42}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_STRING, TensorShape({4})); + test::FillValues(&expected, {"-4.200000e+01", "0.000000e+00", + "3.141590e+00", "4.200000e+01"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(AsStringGraphTest, FloatShortest) { + TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/-1, + /*scientific=*/false, /*shortest=*/true)); + + AddInputFromArray(TensorShape({4}), {-42, 0, 3.14159, 42}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_STRING, TensorShape({4})); + test::FillValues(&expected, {"-42", "0", "3.14159", "42"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(AsStringGraphTest, FloatPrecisionOnly) { + TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/2)); + + AddInputFromArray(TensorShape({4}), {-42, 0, 3.14159, 42}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_STRING, TensorShape({4})); + test::FillValues(&expected, {"-42.00", "0.00", "3.14", "42.00"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(AsStringGraphTest, FloatWidthOnly) { + TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/5)); + + AddInputFromArray(TensorShape({4}), {-42, 0, 3.14159, 42}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_STRING, TensorShape({4})); + test::FillValues( + &expected, {"-42.000000", "0.000000", "3.141590", "42.000000"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(AsStringGraphTest, Float_5_2_Format) { + TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/5, /*precision=*/2)); + + AddInputFromArray(TensorShape({4}), {-42, 0, 3.14159, 42}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_STRING, TensorShape({4})); + test::FillValues(&expected, {"-42.00", " 0.00", " 3.14", "42.00"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(AsStringGraphTest, Complex) { + TF_ASSERT_OK(Init(DT_COMPLEX64, /*fill=*/"", /*width=*/5, /*precision=*/2)); + + AddInputFromArray(TensorShape({3}), {{-4, 2}, {0}, {3.14159, -1}}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_STRING, TensorShape({3})); + test::FillValues( + &expected, {"(-4.00, 2.00)", "( 0.00, 0.00)", "( 3.14,-1.00)"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(AsStringGraphTest, Bool) { + TF_ASSERT_OK(Init(DT_BOOL)); + + AddInputFromArray(TensorShape({2}), {true, false}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_STRING, TensorShape({2})); + test::FillValues(&expected, {"true", "false"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(AsStringGraphTest, String) { + Status s = Init(DT_STRING); + ASSERT_EQ(error::INVALID_ARGUMENT, s.code()); + ASSERT_TRUE(absl::StrContains( + s.error_message(), + "Value for attr 'T' of string is not in the list of allowed values")); +} + +TEST_F(AsStringGraphTest, OnlyOneOfScientificAndShortest) { + Status s = Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/-1, + /*scientific=*/true, /*shortest=*/true); + ASSERT_EQ(error::INVALID_ARGUMENT, s.code()); + ASSERT_TRUE( + absl::StrContains(s.error_message(), + "Cannot select both scientific and shortest notation")); +} + +TEST_F(AsStringGraphTest, NoShortestForNonFloat) { + Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/-1, + /*scientific=*/false, /*shortest=*/true); + ASSERT_EQ(error::INVALID_ARGUMENT, s.code()); + ASSERT_TRUE(absl::StrContains( + s.error_message(), + "scientific and shortest format not supported for datatype")); +} + +TEST_F(AsStringGraphTest, NoScientificForNonFloat) { + Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/-1, + /*scientific=*/true); + ASSERT_EQ(error::INVALID_ARGUMENT, s.code()); + ASSERT_TRUE(absl::StrContains( + s.error_message(), + "scientific and shortest format not supported for datatype")); +} + +TEST_F(AsStringGraphTest, NoPrecisionForNonFloat) { + Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/5); + ASSERT_EQ(error::INVALID_ARGUMENT, s.code()); + ASSERT_TRUE(absl::StrContains(s.error_message(), + "precision not supported for datatype")); +} + +TEST_F(AsStringGraphTest, LongFill) { + Status s = Init(DT_INT32, /*fill=*/"asdf"); + ASSERT_EQ(error::INVALID_ARGUMENT, s.code()); + ASSERT_TRUE(absl::StrContains(s.error_message(), + "Fill string must be one or fewer characters")); +} + +TEST_F(AsStringGraphTest, FillWithZero) { + TF_ASSERT_OK(Init(DT_INT64, /*fill=*/"0", /*width=*/4)); + + AddInputFromArray(TensorShape({3}), {-42, 0, 42}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_STRING, TensorShape({3})); + test::FillValues(&expected, {"-042", "0000", "0042"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(AsStringGraphTest, FillWithSpace) { + TF_ASSERT_OK(Init(DT_INT64, /*fill=*/" ", /*width=*/4)); + + AddInputFromArray(TensorShape({3}), {-42, 0, 42}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_STRING, TensorShape({3})); + test::FillValues(&expected, {" -42", " 0", " 42"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(AsStringGraphTest, FillWithChar1) { + TF_ASSERT_OK(Init(DT_INT64, /*fill=*/"-", /*width=*/4)); + + AddInputFromArray(TensorShape({3}), {-42, 0, 42}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_STRING, TensorShape({3})); + test::FillValues(&expected, {"-42 ", "0 ", "42 "}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(AsStringGraphTest, FillWithChar3) { + Status s = Init(DT_INT32, /*fill=*/"s"); + ASSERT_EQ(error::INVALID_ARGUMENT, s.code()); + ASSERT_TRUE( + absl::StrContains(s.error_message(), "Fill argument not supported")); +} + +TEST_F(AsStringGraphTest, FillWithChar4) { + Status s = Init(DT_INT32, /*fill=*/"n"); + ASSERT_EQ(error::INVALID_ARGUMENT, s.code()); + ASSERT_TRUE( + absl::StrContains(s.error_message(), "Fill argument not supported")); +} + +} // end namespace +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/banded_triangular_solve_op.cc b/tensorflow/core/kernels/banded_triangular_solve_op.cc index d01a015502a..666282e52c8 100644 --- a/tensorflow/core/kernels/banded_triangular_solve_op.cc +++ b/tensorflow/core/kernels/banded_triangular_solve_op.cc @@ -193,7 +193,8 @@ struct LaunchBatchBandedTriangularSolve { Shard(worker_threads.num_threads, worker_threads.workers, batch_size, cost_per_unit, - [&in_x, &in_y, adjoint, lower, &bcast, out](int start, int limit) { + [&in_x, &in_y, adjoint, lower, &bcast, out](int64 start, + int64 limit) { SequentialBandedTriangularSolveKernel::Run( in_x, in_y, lower, adjoint, bcast, out, start, limit); }); diff --git a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc index 19be606f184..e3a908d1b6b 100644 --- a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc +++ b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc @@ -121,7 +121,7 @@ class BoostedTreesTrainingPredictOp : public OpKernel { auto do_work = [&resource, &bucketized_features, &cached_tree_ids, &cached_node_ids, &output_partial_logits, &output_node_ids, latest_tree, - this](int32 start, int32 end) { + this](int64 start, int64 end) { for (int32 i = start; i < end; ++i) { int32 tree_id = cached_tree_ids(i); int32 node_id = cached_node_ids(i); @@ -237,7 +237,7 @@ class BoostedTreesPredictOp : public OpKernel { const int32 last_tree = resource->num_trees() - 1; auto do_work = [&resource, &bucketized_features, &output_logits, last_tree, - this](int32 start, int32 end) { + this](int64 start, int64 end) { for (int32 i = start; i < end; ++i) { std::vector tree_logits(logits_dimension_, 0.0); int32 tree_id = 0; @@ -340,7 +340,7 @@ class BoostedTreesExampleDebugOutputsOp : public OpKernel { // path. Note: feature_ids has one less value than logits_path because the // first value of each logit path will be the bias. auto do_work = [&resource, &bucketized_features, &output_debug_info, - last_tree](int32 start, int32 end) { + last_tree](int64 start, int64 end) { for (int32 i = start; i < end; ++i) { // Proto to store debug outputs, per example. boosted_trees::DebugOutput example_debug_info; diff --git a/tensorflow/core/kernels/count_ops.cc b/tensorflow/core/kernels/count_ops.cc index 7c85b050039..087deef0812 100644 --- a/tensorflow/core/kernels/count_ops.cc +++ b/tensorflow/core/kernels/count_ops.cc @@ -178,10 +178,30 @@ class SparseCount : public OpKernel { const Tensor& weights = context->input(3); bool use_weights = weights.NumElements() > 0; + OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices.shape()), + errors::InvalidArgument( + "Input indices must be a 2-dimensional tensor. Got: ", + indices.shape().DebugString())); + + if (use_weights) { + OP_REQUIRES( + context, weights.shape() == values.shape(), + errors::InvalidArgument( + "Weights and values must have the same shape. Weight shape: ", + weights.shape().DebugString(), + "; values shape: ", values.shape().DebugString())); + } + bool is_1d = shape.NumElements() == 1; int num_batches = is_1d ? 1 : shape.flat()(0); int num_values = values.NumElements(); + OP_REQUIRES(context, num_values == indices.shape().dim_size(0), + errors::InvalidArgument( + "Number of values must match first dimension of indices.", + "Got ", num_values, + " values, indices shape: ", indices.shape().DebugString())); + const auto indices_values = indices.matrix(); const auto values_values = values.flat(); const auto weight_values = weights.flat(); @@ -235,12 +255,33 @@ class RaggedCount : public OpKernel { bool use_weights = weights.NumElements() > 0; bool is_1d = false; + if (use_weights) { + OP_REQUIRES( + context, weights.shape() == values.shape(), + errors::InvalidArgument( + "Weights and values must have the same shape. Weight shape: ", + weights.shape().DebugString(), + "; values shape: ", values.shape().DebugString())); + } + const auto splits_values = splits.flat(); const auto values_values = values.flat(); const auto weight_values = weights.flat(); int num_batches = splits.NumElements() - 1; int num_values = values.NumElements(); + OP_REQUIRES( + context, num_batches > 0, + errors::InvalidArgument( + "Must provide at least 2 elements for the splits argument")); + OP_REQUIRES(context, splits_values(0) == 0, + errors::InvalidArgument("Splits must start with 0, not with ", + splits_values(0))); + OP_REQUIRES(context, splits_values(num_batches) == num_values, + errors::InvalidArgument( + "Splits must end with the number of values, got ", + splits_values(num_batches), " instead of ", num_values)); + auto per_batch_counts = BatchedMap(num_batches); T max_value = 0; int batch_idx = 0; diff --git a/tensorflow/core/kernels/crop_and_resize_op.cc b/tensorflow/core/kernels/crop_and_resize_op.cc index 23058788a4b..4ecd3bc0a79 100644 --- a/tensorflow/core/kernels/crop_and_resize_op.cc +++ b/tensorflow/core/kernels/crop_and_resize_op.cc @@ -223,7 +223,7 @@ struct CropAndResize { const int depth = crops.dimension(3); // Sharding across boxes. - auto CropAndResizePerBox = [&](int start_box, int limit_box) { + auto CropAndResizePerBox = [&](int64 start_box, int64 limit_box) { for (int b = start_box; b < limit_box; ++b) { const float y1 = boxes(b, 0); const float x1 = boxes(b, 1); @@ -449,7 +449,7 @@ struct CropAndResizeBackpropImage { grads_image.setZero(); - auto CropAndResizeBackImgPerBox = [&](int start_box, int limit_box) { + auto CropAndResizeBackImgPerBox = [&](int64 start_box, int64 limit_box) { for (int b = start_box; b < limit_box; ++b) { const float y1 = boxes(b, 0); const float x1 = boxes(b, 1); diff --git a/tensorflow/core/kernels/nth_element_op.cc b/tensorflow/core/kernels/nth_element_op.cc index 0e43cc19aae..bd523f51e27 100644 --- a/tensorflow/core/kernels/nth_element_op.cc +++ b/tensorflow/core/kernels/nth_element_op.cc @@ -95,7 +95,8 @@ struct NthElementFunctor { const int last_dim = input_tensor.dim_size(input_tensor.dims() - 1); // Allocate each row to different shard. - auto SubNthElement = [&, input, output, last_dim, n](int start, int limit) { + auto SubNthElement = [&, input, output, last_dim, n](int64 start, + int64 limit) { // std::nth_element would rearrange the array, so we need a new buffer. std::vector buf(last_dim); diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op.cc index ba1fd280ce7..a63457551ac 100644 --- a/tensorflow/core/kernels/parameterized_truncated_normal_op.cc +++ b/tensorflow/core/kernels/parameterized_truncated_normal_op.cc @@ -70,8 +70,8 @@ struct TruncatedNormalFunctor { auto do_work = [samples_per_batch, num_elements, &ctx, &means, &stddevs, &minvals, &maxvals, &gen, &output, - kStdDevsInsideBoundsToUseRandnSampler](int start_batch, - int limit_batch) { + kStdDevsInsideBoundsToUseRandnSampler](int64 start_batch, + int64 limit_batch) { // Capturing "gen" by-value would only make a copy for the _shared_ // lambda. Since we want to let each worker have its own copy, we pass // "gen" by reference and explicitly do a copy assignment here. @@ -333,8 +333,8 @@ struct TruncatedNormalFunctorV2 { auto do_work = [num_batches, samples_per_batch, &ctx, &bcast, &means, &stddevs, &minvals, &maxvals, &gen, &output, - kStdDevsInsideBoundsToUseRandnSampler](int start_output, - int limit_output) { + kStdDevsInsideBoundsToUseRandnSampler](int64 start_output, + int64 limit_output) { // Capturing "gen" by-value would only make a copy for the _shared_ // lambda. Since we want to let each worker have its own copy, we pass // "gen" by reference and explicitly do a copy assignment here. diff --git a/tensorflow/core/kernels/random_binomial_op.cc b/tensorflow/core/kernels/random_binomial_op.cc index 4647457ff6f..4a730fc70f7 100644 --- a/tensorflow/core/kernels/random_binomial_op.cc +++ b/tensorflow/core/kernels/random_binomial_op.cc @@ -182,7 +182,7 @@ struct RandomBinomialFunctor { // the sample shape and [H1, ... Hm] for the batch shape of the samples. // We have B1 * ... * Bk samples per batch member we need. auto DoWork = [num_batches, samples_per_batch, &bcast, &counts, &probs, - &gen, &output](int start_output, int limit_output) { + &gen, &output](int64 start_output, int64 limit_output) { // Vectorized intermediate calculations for uniform rejection sampling. // We always generate at most 4 samples. Eigen::array z; diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc index 152ab5f7d1e..0f32e759019 100644 --- a/tensorflow/core/kernels/random_op.cc +++ b/tensorflow/core/kernels/random_op.cc @@ -205,7 +205,7 @@ class RandomGammaOp : public OpKernel { // avoid a couple flops which can be done on a per-alpha basis. auto DoWork = [samples_per_alpha, num_alphas, &rng, samples_flat, - alpha_flat](int start_output, int limit_output) { + alpha_flat](int64 start_output, int64 limit_output) { using Eigen::numext::exp; using Eigen::numext::log; using Eigen::numext::log1p; diff --git a/tensorflow/core/kernels/random_poisson_op.cc b/tensorflow/core/kernels/random_poisson_op.cc index aa9a0bfe214..dcb7d6b0f0e 100644 --- a/tensorflow/core/kernels/random_poisson_op.cc +++ b/tensorflow/core/kernels/random_poisson_op.cc @@ -97,7 +97,7 @@ struct PoissonFunctor { typedef random::UniformDistribution Uniform; auto DoWork = [num_samples, num_rate, &rng, samples_flat, rate_flat]( - int start_output, int limit_output) { + int64 start_output, int64 limit_output) { // Capturing "rng" by value would only make a copy for the _shared_ // lambda. Since we want to let each worker have its own copy, we pass // "rng" by reference and explicitly do a copy assignment. diff --git a/tensorflow/core/kernels/session_ops.cc b/tensorflow/core/kernels/session_ops.cc index d83a714452f..e7e73549bc3 100644 --- a/tensorflow/core/kernels/session_ops.cc +++ b/tensorflow/core/kernels/session_ops.cc @@ -16,6 +16,7 @@ limitations under the License. // See docs in ../ops/data_flow_ops.cc. #include + #include #include "tensorflow/core/common_runtime/device.h" @@ -27,6 +28,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" @@ -42,7 +44,11 @@ class GetSessionHandleOp : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor& val = ctx->input(0); - int64 id = ctx->session_state()->GetNewId(); + auto session_state = ctx->session_state(); + OP_REQUIRES(ctx, session_state != nullptr, + errors::FailedPrecondition( + "GetSessionHandle called on null session state")); + int64 id = session_state->GetNewId(); TensorStore::TensorAndKey tk{val, id, requested_device()}; OP_REQUIRES_OK(ctx, ctx->tensor_store()->AddTensor(name(), tk)); diff --git a/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc b/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc index 8de93cf9b30..542069ccd88 100644 --- a/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc +++ b/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc @@ -232,6 +232,9 @@ class SparseFillEmptyRowsGradOp : public OpKernel { context, TensorShapeUtils::IsVector(reverse_index_map_t->shape()), errors::InvalidArgument("reverse_index_map must be a vector, saw: ", reverse_index_map_t->shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsVector(grad_values_t->shape()), + errors::InvalidArgument("grad_values must be a vector, saw: ", + grad_values_t->shape().DebugString())); const auto reverse_index_map = reverse_index_map_t->vec(); const auto grad_values = grad_values_t->vec(); @@ -260,8 +263,13 @@ class SparseFillEmptyRowsGradOp : public OpKernel { // Locate the index of the output of the forward prop associated // with this location in the input of the forward prop. Copy // the gradient into it. Mark it as visited. - d_values(i) = grad_values(reverse_index_map(i)); - visited(reverse_index_map(i)) = true; + int64 reverse_index = reverse_index_map(i); + OP_REQUIRES( + context, 0 <= reverse_index && reverse_index < N_full, + errors::InvalidArgument("Elements in reverse index must be in [0, ", + N_full, ") but got ", reverse_index)); + d_values(i) = grad_values(reverse_index); + visited(reverse_index) = true; } for (int j = 0; j < N_full; ++j) { // The default value gradient gets the accumulated remainder of diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc index 6738a34e3fd..3150f168828 100644 --- a/tensorflow/core/kernels/stateless_random_ops.cc +++ b/tensorflow/core/kernels/stateless_random_ops.cc @@ -252,7 +252,7 @@ class StatelessRandomGammaOp : public StatelessRandomOpBase { // avoid a couple flops which can be done on a per-alpha basis. auto DoWork = [samples_per_alpha, num_alphas, &random, samples_flat, - alpha_flat](int start_output, int limit_output) { + alpha_flat](int64 start_output, int64 limit_output) { // Capturing "random" by-value would only make a copy for the _shared_ // lambda. Since we want to let each worker have its own copy, we pass // "random" by reference and explicitly do a copy assignment. diff --git a/tensorflow/core/kernels/string_ngrams_op.cc b/tensorflow/core/kernels/string_ngrams_op.cc index 97b32c4242c..8aed2b3831a 100644 --- a/tensorflow/core/kernels/string_ngrams_op.cc +++ b/tensorflow/core/kernels/string_ngrams_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace text { @@ -60,6 +61,18 @@ class StringNGramsOp : public tensorflow::OpKernel { OP_REQUIRES_OK(context, context->input("data_splits", &splits)); const auto& splits_vec = splits->flat(); + // Validate that the splits are valid indices into data + const int input_data_size = data->flat().size(); + const int splits_vec_size = splits_vec.size(); + for (int i = 0; i < splits_vec_size; ++i) { + bool valid_splits = splits_vec(i) >= 0; + valid_splits = valid_splits && (splits_vec(i) <= input_data_size); + OP_REQUIRES( + context, valid_splits, + errors::InvalidArgument("Invalid split value ", splits_vec(i), + ", must be in [0,", input_data_size, "]")); + } + int num_batch_items = splits_vec.size() - 1; tensorflow::Tensor* ngrams_splits; OP_REQUIRES_OK( diff --git a/tensorflow/core/kernels/topk_op.cc b/tensorflow/core/kernels/topk_op.cc index c555b42f005..e2659bbf9d5 100644 --- a/tensorflow/core/kernels/topk_op.cc +++ b/tensorflow/core/kernels/topk_op.cc @@ -136,7 +136,7 @@ struct TopKFunctor { return Status::OK(); } - auto SortIndices = [&](int start_batch, int limit_batch) { + auto SortIndices = [&](int64 start_batch, int64 limit_batch) { for (int32 b = start_batch; b < limit_batch; ++b) { const T* input_data = &input(b, 0); const auto stable_comp = [input_data](const int32 a, const int32 b) { diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 0f11af51488..00a37815d21 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/lite/arena_planner.h" +#include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/context_util.h" #include "tensorflow/lite/core/api/tensor_utils.h" @@ -567,6 +568,33 @@ TfLiteStatus Subgraph::CheckTensorIndices(const char* label, const int* indices, return kTfLiteOk; } +// We have two arrays and we need to check that elements from one array don't +// show up in the other. We could sort both arrays and then iterate with two +// pointers from start to finish always increasing the smaller one but since +// these arrays are usually short (<25 elements for inputs, usually <3 for +// outputs), this might be slower than the naive approach (if arrays have size n +// and m, with n >> m ~ O(1), first approach is O(nlogn) whereas the other is +// O(n)). Plus, sorting the input and output arrays might not be something we +// want as it destroys ordering of elements. +// +// If it turns out that this is an issue, we can switch to the other algorithm. +TfLiteStatus Subgraph::CheckInputAndOutputForOverlap(const int* input_indices, + int num_inputs, + const int* output_indices, + int num_outputs) { + for (int i = 0; i < num_inputs; i++) { + for (int j = 0; j < num_outputs; j++) { + if (input_indices[i] == output_indices[j]) { + ReportError("Tensor %d is both input %d and output %d\n", + input_indices[i], i, j); + consistent_ = false; + return kTfLiteError; + } + } + } + return kTfLiteOk; +} + namespace { // Multiply two sizes and return true if overflow occurred; // This is based off tensorflow/overflow.h but is simpler as we already @@ -688,6 +716,16 @@ TfLiteStatus Subgraph::AddNodeWithParameters( &context_, CheckTensorIndices("node outputs", outputs.data(), outputs.size())); + // For builtin ops, inputs and outputs must not overlap. Custom ops must do + // this check by themselves if they don't support overlapping tensors. This + // distinction is to allow custom ops to just forward a tensor, reusing it as + // both input and output. + if (builtin_data != nullptr) { + TF_LITE_ENSURE_OK(&context_, CheckInputAndOutputForOverlap( + inputs.data(), inputs.size(), + outputs.data(), outputs.size())); + } + int new_node_index = nodes_and_registration_.size(); if (node_index) *node_index = new_node_index; nodes_and_registration_.resize(nodes_and_registration_.size() + 1); @@ -934,6 +972,19 @@ TfLiteStatus Subgraph::Invoke() { tensor->data_is_stale) { TF_LITE_ENSURE_STATUS(EnsureTensorDataIsReadable(tensor_index)); } + if (tensor->data.raw == nullptr && tensor->bytes > 0) { + if (registration.builtin_code == kTfLiteBuiltinReshape && i == 1) { + // In general, having a tensor here with no buffer will be an error. + // However, for the reshape operator, the second input tensor is only + // used for the shape, not for the data. Thus, null buffer is ok. + continue; + } else { + // In all other cases, we need to return an error as otherwise we will + // trigger a null pointer dereference (likely). + ReportError("Input tensor %d lacks data", tensor_index); + return kTfLiteError; + } + } } if (check_cancelled_func_ != nullptr && diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h index bee13c9073e..979c709614c 100644 --- a/tensorflow/lite/core/subgraph.h +++ b/tensorflow/lite/core/subgraph.h @@ -433,6 +433,15 @@ class Subgraph { TfLiteStatus CheckTensorIndices(const char* label, const int* indices, int length); + // Check that the input indices and the output indices don't overlap. + // This is needed because same tensor must not be used both as input and + // output for an operator. + // NOTE: this changes consistent_ to be false if indices are out of bounds. + TfLiteStatus CheckInputAndOutputForOverlap(const int* input_indices, + int num_inputs, + const int* output_indices, + int num_outputs); + // Compute the number of bytes required to represent a tensor with dimensions // specified by the array dims (of length dims_size). Returns the status code // and bytes. diff --git a/tensorflow/lite/interpreter_builder.cc b/tensorflow/lite/interpreter_builder.cc index 4b491d41881..3c457523ca6 100644 --- a/tensorflow/lite/interpreter_builder.cc +++ b/tensorflow/lite/interpreter_builder.cc @@ -609,7 +609,12 @@ TfLiteStatus InterpreterBuilder::operator()( auto* buffers = model_->buffers(); if (subgraphs->size() == 0) { - error_reporter_->Report("No subgraph in the model.\n"); + TF_LITE_REPORT_ERROR(error_reporter_, "No subgraph in the model.\n"); + return cleanup_and_error(); + } + + if (!buffers) { + TF_LITE_REPORT_ERROR(error_reporter_, "No buffers in the model.\n"); return cleanup_and_error(); } @@ -630,10 +635,10 @@ TfLiteStatus InterpreterBuilder::operator()( (*interpreter)->subgraph(subgraph_index); auto operators = subgraph->operators(); auto tensors = subgraph->tensors(); - if (!operators || !tensors || !buffers) { - error_reporter_->Report( - "Did not get operators, tensors, or buffers in subgraph %d.\n", - subgraph_index); + if (!operators || !tensors) { + TF_LITE_REPORT_ERROR(error_reporter_, + "Did not get operators or tensors in subgraph %d.\n", + subgraph_index); return cleanup_and_error(); } if (modified_subgraph->AddTensors(tensors->size()) != kTfLiteOk) { diff --git a/tensorflow/lite/kernels/internal/reference/reduce.h b/tensorflow/lite/kernels/internal/reference/reduce.h index fbad266e843..8291141618f 100644 --- a/tensorflow/lite/kernels/internal/reference/reduce.h +++ b/tensorflow/lite/kernels/internal/reference/reduce.h @@ -70,6 +70,9 @@ inline bool ResolveAxis(const int num_dims, const int* axis, // eg: For num_dims=3, [0, 1, 2] is the same as [-3, -2, -1] */ int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx]; TFLITE_DCHECK(current >= 0 && current < num_dims); + if (current < 0 || current >= num_dims) { + return false; + } bool is_dup = false; for (int j = 0; j < *out_num_axis; ++j) { if (out_axis[j] == current) { diff --git a/tensorflow/lite/kernels/internal/types.h b/tensorflow/lite/kernels/internal/types.h index 2a34f6608a3..adbd34b0146 100644 --- a/tensorflow/lite/kernels/internal/types.h +++ b/tensorflow/lite/kernels/internal/types.h @@ -432,7 +432,7 @@ int MatchingArraySize(const ArrayType1& array1, int index1, inline int MatchingDim(const RuntimeShape& shape1, int index1, const RuntimeShape& shape2, int index2) { TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2)); - return shape1.Dims(index1); + return std::min(shape1.Dims(index1), shape2.Dims(index2)); } template diff --git a/tensorflow/lite/kernels/kernel_util.h b/tensorflow/lite/kernels/kernel_util.h index 6bd6bb1c7ed..59b1974c3b9 100644 --- a/tensorflow/lite/kernels/kernel_util.h +++ b/tensorflow/lite/kernels/kernel_util.h @@ -30,27 +30,48 @@ inline int SizeOfDimension(const TfLiteTensor* t, int dim) { } inline const TfLiteTensor* GetInput(const TfLiteContext* context, const TfLiteNode* node, int index) { - return &context->tensors[node->inputs->data[index]]; + const int tensor_index = node->inputs->data[index]; + if (tensor_index < 0) { + return nullptr; + } + return &context->tensors[tensor_index]; } // Note: You must check if result is not null: // TfLiteTensor* my_tensor = GetVariableInput(context, node, kMyTensorIdx); // TF_LITE_ENSURE(context, my_tensor != nullptr); inline TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node, int index) { - TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]]; + const int tensor_index = node->inputs->data[index]; + if (tensor_index < 0) { + return nullptr; + } + TfLiteTensor* tensor = &context->tensors[tensor_index]; return (tensor->is_variable) ? tensor : nullptr; } inline TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node, int index) { - return &context->tensors[node->outputs->data[index]]; + const int tensor_index = node->outputs->data[index]; + if (tensor_index < 0) { + return nullptr; + } + return &context->tensors[tensor_index]; } inline TfLiteTensor* GetTemporary(TfLiteContext* context, const TfLiteNode* node, int index) { - return &context->tensors[node->temporaries->data[index]]; + const int tensor_index = node->temporaries->data[index]; + if (tensor_index < 0) { + return nullptr; + } + return &context->tensors[tensor_index]; } + inline const TfLiteTensor* GetIntermediates(TfLiteContext* context, const TfLiteNode* node, int index) { - return &context->tensors[node->intermediates->data[index]]; + const int tensor_index = node->intermediates->data[index]; + if (tensor_index < 0) { + return nullptr; + } + return &context->tensors[tensor_index]; } inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; } inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; } @@ -73,12 +94,7 @@ inline int64_t NumElements(const TfLiteTensor* t) { inline const TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context, const TfLiteNode* node, int index) { - const bool use_tensor = index < node->inputs->size && - node->inputs->data[index] != kTfLiteOptionalTensor; - if (use_tensor) { - return &context->tensors[node->inputs->data[index]]; - } - return nullptr; + return GetInput(context, node, index); } // Determines whether tensor is constant. diff --git a/tensorflow/lite/kernels/segment_sum.cc b/tensorflow/lite/kernels/segment_sum.cc index 8185359321e..4b762184a50 100644 --- a/tensorflow/lite/kernels/segment_sum.cc +++ b/tensorflow/lite/kernels/segment_sum.cc @@ -34,11 +34,24 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, const TfLiteTensor* data, const TfLiteTensor* segment_ids, TfLiteTensor* output) { - int max_index = -1; + // Segment ids should be of same cardinality as first input dimension and they + // should be increasing by at most 1, from 0 (e.g., [0, 0, 1, 2, 3] is valid) const int segment_id_size = segment_ids->dims->data[0]; - if (segment_id_size > 0) { - max_index = segment_ids->data.i32[segment_id_size - 1]; + TF_LITE_ENSURE_EQ(context, segment_id_size, data->dims->data[0]); + int previous_segment_id = -1; + for (int i = 0; i < segment_id_size; i++) { + const int current_segment_id = GetTensorData(segment_ids)[i]; + if (i == 0) { + TF_LITE_ENSURE_EQ(context, current_segment_id, 0); + } else { + int delta = current_segment_id - previous_segment_id; + TF_LITE_ENSURE(context, delta == 0 || delta == 1); + } + previous_segment_id = current_segment_id; } + + const int max_index = previous_segment_id; + const int data_rank = NumDimensions(data); TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(data)); output_shape->data[0] = max_index + 1; diff --git a/tensorflow/lite/kernels/segment_sum_test.cc b/tensorflow/lite/kernels/segment_sum_test.cc index ec531ffd92d..286742c0933 100644 --- a/tensorflow/lite/kernels/segment_sum_test.cc +++ b/tensorflow/lite/kernels/segment_sum_test.cc @@ -110,5 +110,37 @@ TEST(SegmentSumOpModelTest, Float32Test_ThreeDimensions) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 1})); } +TEST(SegmentSumOpModelTest, TestFailIfSegmentsAreNotSorted) { + SegmentSumOpModel model({TensorType_INT32, {3, 2}}, + {TensorType_INT32, {3}}); + model.PopulateTensor(model.data(), {1, 2, 3, 4, 5, 6}); + model.PopulateTensor(model.segment_ids(), {0, 3, 1}); + ASSERT_EQ(model.InvokeUnchecked(), kTfLiteError); +} + +TEST(SegmentSumOpModelTest, TestFailIfSegmentsAreNotConsecutive) { + SegmentSumOpModel model({TensorType_INT32, {3, 2}}, + {TensorType_INT32, {3}}); + model.PopulateTensor(model.data(), {1, 2, 3, 4, 5, 6}); + model.PopulateTensor(model.segment_ids(), {0, 3, 5}); + ASSERT_EQ(model.InvokeUnchecked(), kTfLiteError); +} + +TEST(SegmentSumOpModelTest, TestFailIfSegmentsAreNegative) { + SegmentSumOpModel model({TensorType_INT32, {3, 2}}, + {TensorType_INT32, {3}}); + model.PopulateTensor(model.data(), {1, 2, 3, 4, 5, 6}); + model.PopulateTensor(model.segment_ids(), {-1, 0, 1}); + ASSERT_EQ(model.InvokeUnchecked(), kTfLiteError); +} + +TEST(SegmentSumOpModelTest, TestFailIfSegmentsAreNotTheRightCardinality) { + SegmentSumOpModel model({TensorType_INT32, {3, 2}}, + {TensorType_INT32, {2}}); + model.PopulateTensor(model.data(), {1, 2, 3, 4, 5, 6}); + model.PopulateTensor(model.segment_ids(), {0, 1}); + ASSERT_EQ(model.InvokeUnchecked(), kTfLiteError); +} + } // namespace } // namespace tflite diff --git a/tensorflow/python/dlpack/dlpack_test.py b/tensorflow/python/dlpack/dlpack_test.py index af91da80512..df53220849c 100644 --- a/tensorflow/python/dlpack/dlpack_test.py +++ b/tensorflow/python/dlpack/dlpack_test.py @@ -20,9 +20,11 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np + from tensorflow.python.dlpack import dlpack from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import test @@ -95,6 +97,12 @@ class DLPackTest(parameterized.TestCase, test.TestCase): self.assertRaisesRegex(Exception, ".* is not supported by dlpack", UnsupportedComplex64) + def testMustPassTensorArgumentToDLPack(self): + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "The argument to `to_dlpack` must be a TF tensor, not Python object"): + dlpack.to_dlpack([1]) + if __name__ == "__main__": ops.enable_eager_execution() diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index eec7165d148..0f1485515fb 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -4581,6 +4581,14 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase): result = control_flow_ops.merge([v_f, v_t]) self.evaluate(result) + def testSwitchEagerMode(self): + if not context.executing_eagerly(): + return + input_data = [1, 2, 3, 4] + vf, vt = control_flow_ops.switch(input_data, False) + self.assertAllEqual(vf, input_data) + self.assertAllEqual(vt, []) + @test_util.run_deprecated_v1 def testQIntArgAndRet(self): diff --git a/tensorflow/python/ops/bincount_ops_test.py b/tensorflow/python/ops/bincount_ops_test.py index 74fd17cae2b..e9906e32f95 100644 --- a/tensorflow/python/ops/bincount_ops_test.py +++ b/tensorflow/python/ops/bincount_ops_test.py @@ -25,7 +25,9 @@ from tensorflow.python.eager import context from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util from tensorflow.python.ops import bincount_ops +from tensorflow.python.ops import gen_count_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_tensor @@ -834,5 +836,121 @@ class TestSparseCountFailureModes(test.TestCase): self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1)) +@test_util.run_all_in_graph_and_eager_modes +@test_util.disable_tfrt +class RawOpsTest(test.TestCase, parameterized.TestCase): + + def testSparseCountSparseOutputBadIndicesShape(self): + indices = [[[0], [0]], [[0], [1]], [[1], [0]], [[1], [2]]] + values = [1, 1, 1, 10] + weights = [1, 2, 4, 6] + dense_shape = [2, 3] + with self.assertRaisesRegex(errors.InvalidArgumentError, + "Input indices must be a 2-dimensional tensor"): + self.evaluate( + gen_count_ops.SparseCountSparseOutput( + indices=indices, + values=values, + dense_shape=dense_shape, + weights=weights, + binary_output=False)) + + def testSparseCountSparseOutputBadWeightsShape(self): + indices = [[0, 0], [0, 1], [1, 0], [1, 2]] + values = [1, 1, 1, 10] + weights = [1, 2, 4] + dense_shape = [2, 3] + with self.assertRaisesRegex(errors.InvalidArgumentError, + "Weights and values must have the same shape"): + self.evaluate( + gen_count_ops.SparseCountSparseOutput( + indices=indices, + values=values, + dense_shape=dense_shape, + weights=weights, + binary_output=False)) + + def testSparseCountSparseOutputBadNumberOfValues(self): + indices = [[0, 0], [0, 1], [1, 0]] + values = [1, 1, 1, 10] + weights = [1, 2, 4, 6] + dense_shape = [2, 3] + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Number of values must match first dimension of indices"): + self.evaluate( + gen_count_ops.SparseCountSparseOutput( + indices=indices, + values=values, + dense_shape=dense_shape, + weights=weights, + binary_output=False)) + + def testRaggedCountSparseOutput(self): + splits = [0, 4, 7] + values = [1, 1, 2, 1, 2, 10, 5] + weights = [1, 2, 3, 4, 5, 6, 7] + output_indices, output_values, output_shape = self.evaluate( + gen_count_ops.RaggedCountSparseOutput( + splits=splits, values=values, weights=weights, binary_output=False)) + self.assertAllEqual([[0, 1], [0, 2], [1, 2], [1, 5], [1, 10]], + output_indices) + self.assertAllEqual([7, 3, 5, 7, 6], output_values) + self.assertAllEqual([2, 11], output_shape) + + def testRaggedCountSparseOutputBadWeightsShape(self): + splits = [0, 4, 7] + values = [1, 1, 2, 1, 2, 10, 5] + weights = [1, 2, 3, 4, 5, 6] + with self.assertRaisesRegex(errors.InvalidArgumentError, + "Weights and values must have the same shape"): + self.evaluate( + gen_count_ops.RaggedCountSparseOutput( + splits=splits, + values=values, + weights=weights, + binary_output=False)) + + def testRaggedCountSparseOutputEmptySplits(self): + splits = [] + values = [1, 1, 2, 1, 2, 10, 5] + weights = [1, 2, 3, 4, 5, 6, 7] + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Must provide at least 2 elements for the splits argument"): + self.evaluate( + gen_count_ops.RaggedCountSparseOutput( + splits=splits, + values=values, + weights=weights, + binary_output=False)) + + def testRaggedCountSparseOutputBadSplitsStart(self): + splits = [1, 7] + values = [1, 1, 2, 1, 2, 10, 5] + weights = [1, 2, 3, 4, 5, 6, 7] + with self.assertRaisesRegex(errors.InvalidArgumentError, + "Splits must start with 0"): + self.evaluate( + gen_count_ops.RaggedCountSparseOutput( + splits=splits, + values=values, + weights=weights, + binary_output=False)) + + def testRaggedCountSparseOutputBadSplitsEnd(self): + splits = [0, 5] + values = [1, 1, 2, 1, 2, 10, 5] + weights = [1, 2, 3, 4, 5, 6, 7] + with self.assertRaisesRegex(errors.InvalidArgumentError, + "Splits must end with the number of values"): + self.evaluate( + gen_count_ops.RaggedCountSparseOutput( + splits=splits, + values=values, + weights=weights, + binary_output=False)) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/raw_ops_test.py b/tensorflow/python/ops/raw_ops_test.py index fff94f5c25a..0dbd7dcb916 100644 --- a/tensorflow/python/ops/raw_ops_test.py +++ b/tensorflow/python/ops/raw_ops_test.py @@ -18,16 +18,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.eager import context from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import gen_string_ops from tensorflow.python.platform import test @test_util.run_all_in_graph_and_eager_modes -class RawOpsTest(test.TestCase): +@test_util.disable_tfrt +class RawOpsTest(test.TestCase, parameterized.TestCase): def testSimple(self): x = constant_op.constant(1) @@ -58,6 +64,29 @@ class RawOpsTest(test.TestCase): gen_math_ops.Any(input=x, axis=0), gen_math_ops.Any(input=x, axis=0, keep_dims=False)) + @parameterized.parameters([[0, 8]], [[-1, 6]]) + def testStringNGramsBadDataSplits(self, splits): + data = ["aa", "bb", "cc", "dd", "ee", "ff"] + with self.assertRaisesRegex(errors.InvalidArgumentError, + "Invalid split value"): + self.evaluate( + gen_string_ops.string_n_grams( + data=data, + data_splits=splits, + separator="", + ngram_widths=[2], + left_pad="", + right_pad="", + pad_width=0, + preserve_short_sequences=False)) + + def testGetSessionHandle(self): + if context.executing_eagerly(): + with self.assertRaisesRegex( + errors.FailedPreconditionError, + "GetSessionHandle called on null session state"): + gen_data_flow_ops.GetSessionHandle(value=[1]) + if __name__ == "__main__": ops.enable_eager_execution() diff --git a/tensorflow/python/ops/sparse_ops_test.py b/tensorflow/python/ops/sparse_ops_test.py index 91151ba8461..0b014b55d10 100644 --- a/tensorflow/python/ops/sparse_ops_test.py +++ b/tensorflow/python/ops/sparse_ops_test.py @@ -21,6 +21,7 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -29,6 +30,7 @@ from tensorflow.python.framework import test_util # Need array_grad to register gradient for Identity. from tensorflow.python.ops import array_grad # pylint: disable=unused-import from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_sparse_ops from tensorflow.python.ops import gradient_checker_v2 as gradient_checker from tensorflow.python.ops import math_ops # Need sparse_grad to register gradient for SparseToDense. @@ -181,5 +183,57 @@ class SparseOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): self.assertAllEqual(expected, result) +@test_util.run_all_in_graph_and_eager_modes +class RawOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): + + def testSparseFillEmptyRowsGrad(self): + reverse_index_map = [2, 1] + grad_values = [0, 1, 2, 3] + d_values, d_default_value = self.evaluate( + gen_sparse_ops.SparseFillEmptyRowsGrad( + reverse_index_map=reverse_index_map, grad_values=grad_values)) + self.assertAllEqual([2, 1], d_values) + self.assertEqual(3, d_default_value) + + def testSparseFillEmptyRowsGradNegativeIndexMapValue(self): + reverse_index_map = [2, -1] + grad_values = [0, 1, 2, 3] + with self.assertRaisesRegex( + errors.InvalidArgumentError, + r'Elements in reverse index must be in \[0, 4\)'): + self.evaluate( + gen_sparse_ops.SparseFillEmptyRowsGrad( + reverse_index_map=reverse_index_map, grad_values=grad_values)) + + def testSparseFillEmptyRowsGradLargeIndexMapValue(self): + reverse_index_map = [2, 10] + grad_values = [0, 1, 2, 3] + with self.assertRaisesRegex( + errors.InvalidArgumentError, + r'Elements in reverse index must be in \[0, 4\)'): + self.evaluate( + gen_sparse_ops.SparseFillEmptyRowsGrad( + reverse_index_map=reverse_index_map, grad_values=grad_values)) + + def testSparseFillEmptyRowsGradMatrix(self): + reverse_index_map = [0, 1] + grad_values = [[0, 1], [2, 3]] + # Note: Eager mode and graph mode throw different errors here. Graph mode + # will fail with a ValueError from the shape checking logic, while Eager + # will fail with an InvalidArgumentError from the kernel itself. + if context.executing_eagerly(): + with self.assertRaisesRegex(errors.InvalidArgumentError, + r'grad_values must be a vector'): + self.evaluate( + gen_sparse_ops.SparseFillEmptyRowsGrad( + reverse_index_map=reverse_index_map, grad_values=grad_values)) + else: + with self.assertRaisesRegex(ValueError, + r'Shape must be rank 1 but is rank 2'): + self.evaluate( + gen_sparse_ops.SparseFillEmptyRowsGrad( + reverse_index_map=reverse_index_map, grad_values=grad_values)) + + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 88bb66f189b..3401020ae99 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -1129,9 +1129,16 @@ PYBIND11_MODULE(_pywrap_tfe, m) { // DLPack functions m.def("TFE_ToDlpackCapsule", [](py::handle& o) { PyObject* eager_tensor_pyobject_ptr = o.ptr(); - TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr); tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); + + if (!EagerTensor_CheckExact(eager_tensor_pyobject_ptr)) { + status->status = tensorflow::errors::InvalidArgument( + "The argument to `to_dlpack` must be a TF tensor, not Python object"); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + } + + TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr); void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get()); tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());