Merge pull request #43358 from tensorflow/mm-patch-r2.3
Patch for TF 2.3.1
This commit is contained in:
commit
9cf3773b71
tensorflow
c/eager
cc/saved_model
core
common_runtime/eager
kernels
BUILDas_string_op.ccas_string_op_test.ccbanded_triangular_solve_op.cc
boosted_trees
count_ops.cccrop_and_resize_op.ccnth_element_op.ccparameterized_truncated_normal_op.ccrandom_binomial_op.ccrandom_op.ccrandom_poisson_op.ccsession_ops.ccsparse_fill_empty_rows_op.ccstateless_random_ops.ccstring_ngrams_op.cctopk_op.cclite
python
@ -248,21 +248,36 @@ void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
|
void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
|
||||||
|
auto tf_dlm_context = GetDlContext(h, status);
|
||||||
|
if (!status->status.ok()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto* tf_dlm_data = TFE_TensorHandleDevicePointer(h, status);
|
||||||
|
if (!status->status.ok()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
const Tensor* tensor = GetTensorFromHandle(h, status);
|
const Tensor* tensor = GetTensorFromHandle(h, status);
|
||||||
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
|
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
|
||||||
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
|
|
||||||
|
|
||||||
|
auto tf_dlm_type = GetDlDataType(data_type, status);
|
||||||
|
if (!status->status.ok()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
|
||||||
auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
|
auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
|
||||||
tf_dlm_tensor_ctx->reference = tensor_ref;
|
tf_dlm_tensor_ctx->reference = tensor_ref;
|
||||||
|
|
||||||
DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
|
DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
|
||||||
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
|
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
|
||||||
dlm_tensor->deleter = &DLManagedTensorDeleter;
|
dlm_tensor->deleter = &DLManagedTensorDeleter;
|
||||||
dlm_tensor->dl_tensor.ctx = GetDlContext(h, status);
|
dlm_tensor->dl_tensor.ctx = tf_dlm_context;
|
||||||
int ndim = tensor->dims();
|
int ndim = tensor->dims();
|
||||||
dlm_tensor->dl_tensor.ndim = ndim;
|
dlm_tensor->dl_tensor.ndim = ndim;
|
||||||
dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
|
dlm_tensor->dl_tensor.data = tf_dlm_data;
|
||||||
dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status);
|
dlm_tensor->dl_tensor.dtype = tf_dlm_type;
|
||||||
|
|
||||||
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
|
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
|
||||||
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
|
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
|
||||||
@ -275,13 +290,14 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
|
|||||||
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
|
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
|
dlm_tensor->dl_tensor.shape = shape_arr->data();
|
||||||
// There are two ways to represent compact row-major data
|
// There are two ways to represent compact row-major data
|
||||||
// 1) nullptr indicates tensor is compact and row-majored.
|
// 1) nullptr indicates tensor is compact and row-majored.
|
||||||
// 2) fill in the strides array as the real case for compact row-major data.
|
// 2) fill in the strides array as the real case for compact row-major data.
|
||||||
// Here we choose option 2, since some frameworks didn't handle the strides
|
// Here we choose option 2, since some frameworks didn't handle the strides
|
||||||
// argument properly.
|
// argument properly.
|
||||||
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
|
dlm_tensor->dl_tensor.strides = stride_arr->data();
|
||||||
|
|
||||||
dlm_tensor->dl_tensor.byte_offset =
|
dlm_tensor->dl_tensor.byte_offset =
|
||||||
0; // TF doesn't handle the strides and byte_offsets here
|
0; // TF doesn't handle the strides and byte_offsets here
|
||||||
return static_cast<void*>(dlm_tensor);
|
return static_cast<void*>(dlm_tensor);
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/cc/saved_model/loader_util.h"
|
#include "tensorflow/cc/saved_model/loader_util.h"
|
||||||
#include "tensorflow/cc/saved_model/reader.h"
|
#include "tensorflow/cc/saved_model/reader.h"
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
|
#include "tensorflow/core/framework/function.proto.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/tensor.pb.h"
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
@ -72,8 +73,7 @@ uint64 GetLatencyMicroseconds(const uint64 start_microseconds) {
|
|||||||
// Ensure that constant tensors loaded from the saved model have valid shape.
|
// Ensure that constant tensors loaded from the saved model have valid shape.
|
||||||
// Also ensure that constant nodes have a value assigned to them.
|
// Also ensure that constant nodes have a value assigned to them.
|
||||||
// TODO(b/154763635): this is temporary and will be replaced with a better audit
|
// TODO(b/154763635): this is temporary and will be replaced with a better audit
|
||||||
static Status ValidateSavedTensors(const GraphDef& graph_def) {
|
static Status ValidateNode(const NodeDef& node) {
|
||||||
for (const auto& node : graph_def.node()) {
|
|
||||||
const auto node_iterator = node.attr().find("value");
|
const auto node_iterator = node.attr().find("value");
|
||||||
if (node_iterator != node.attr().end()) {
|
if (node_iterator != node.attr().end()) {
|
||||||
AttrValue node_value = node_iterator->second;
|
AttrValue node_value = node_iterator->second;
|
||||||
@ -81,8 +81,8 @@ static Status ValidateSavedTensors(const GraphDef& graph_def) {
|
|||||||
const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
|
const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
|
||||||
if (node_shape.num_elements() < 0) {
|
if (node_shape.num_elements() < 0) {
|
||||||
return errors::FailedPrecondition(
|
return errors::FailedPrecondition(
|
||||||
"Saved model contains node \"", node.name(), "\" (op \"",
|
"Saved model contains node \"", node.name(), "\" (op \"", node.op(),
|
||||||
node.op(), "\") which initializes from a tensor with ",
|
"\") which initializes from a tensor with ",
|
||||||
node_shape.num_elements(), " elements");
|
node_shape.num_elements(), " elements");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -91,7 +91,23 @@ static Status ValidateSavedTensors(const GraphDef& graph_def) {
|
|||||||
"Saved model contains node \"", node.name(),
|
"Saved model contains node \"", node.name(),
|
||||||
"\" which is a constant tensor but no value has been provided");
|
"\" which is a constant tensor but no value has been provided");
|
||||||
}
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
static Status ValidateSavedTensors(const GraphDef& graph_def) {
|
||||||
|
for (const auto& node : graph_def.node()) {
|
||||||
|
TF_RETURN_IF_ERROR(ValidateNode(node));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (graph_def.has_library()) {
|
||||||
|
const FunctionDefLibrary& library = graph_def.library();
|
||||||
|
for (const auto& function : library.function()) {
|
||||||
|
for (const auto& node : function.node_def()) {
|
||||||
|
TF_RETURN_IF_ERROR(ValidateNode(node));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -307,7 +307,12 @@ Status KernelAndDeviceOp::Run(
|
|||||||
if (outputs != nullptr) {
|
if (outputs != nullptr) {
|
||||||
outputs->clear();
|
outputs->clear();
|
||||||
for (int i = 0; i < context.num_outputs(); ++i) {
|
for (int i = 0; i < context.num_outputs(); ++i) {
|
||||||
outputs->push_back(Tensor(*context.mutable_output(i)));
|
const auto* output_tensor = context.mutable_output(i);
|
||||||
|
if (output_tensor != nullptr) {
|
||||||
|
outputs->push_back(Tensor(*output_tensor));
|
||||||
|
} else {
|
||||||
|
outputs->push_back(Tensor());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -6085,6 +6085,24 @@ tf_kernel_library(
|
|||||||
deps = STRING_DEPS,
|
deps = STRING_DEPS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "as_string_op_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["as_string_op_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":as_string_op",
|
||||||
|
":ops_testutil",
|
||||||
|
":ops_util",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
name = "unicode_ops",
|
name = "unicode_ops",
|
||||||
prefix = "unicode_ops",
|
prefix = "unicode_ops",
|
||||||
|
@ -65,9 +65,26 @@ class AsStringOp : public OpKernel {
|
|||||||
OP_REQUIRES(ctx, !(scientific && shortest),
|
OP_REQUIRES(ctx, !(scientific && shortest),
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
"Cannot select both scientific and shortest notation"));
|
"Cannot select both scientific and shortest notation"));
|
||||||
|
|
||||||
format_ = "%";
|
format_ = "%";
|
||||||
|
if (!fill_string.empty()) {
|
||||||
|
switch (fill_string[0]) {
|
||||||
|
case ' ':
|
||||||
|
case '+':
|
||||||
|
case '-':
|
||||||
|
case '0':
|
||||||
|
case '#':
|
||||||
|
strings::Appendf(&format_, "%s", fill_string.c_str());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
bool fill_not_supported = true;
|
||||||
|
OP_REQUIRES(ctx, !fill_not_supported,
|
||||||
|
errors::InvalidArgument("Fill argument not supported: \"",
|
||||||
|
fill_string, "\""));
|
||||||
|
}
|
||||||
|
}
|
||||||
if (width > -1) {
|
if (width > -1) {
|
||||||
strings::Appendf(&format_, "%s%d", fill_string.c_str(), width);
|
strings::Appendf(&format_, "%d", width);
|
||||||
}
|
}
|
||||||
if (precision > -1) {
|
if (precision > -1) {
|
||||||
strings::Appendf(&format_, ".%d", precision);
|
strings::Appendf(&format_, ".%d", precision);
|
||||||
|
245
tensorflow/core/kernels/as_string_op_test.cc
Normal file
245
tensorflow/core/kernels/as_string_op_test.cc
Normal file
@ -0,0 +1,245 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/fake_input.h"
|
||||||
|
#include "tensorflow/core/framework/node_def_builder.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||||
|
#include "tensorflow/core/kernels/ops_util.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class AsStringGraphTest : public OpsTestBase {
|
||||||
|
protected:
|
||||||
|
Status Init(DataType input_type, const string& fill = "", int width = -1,
|
||||||
|
int precision = -1, bool scientific = false,
|
||||||
|
bool shortest = false) {
|
||||||
|
TF_CHECK_OK(NodeDefBuilder("op", "AsString")
|
||||||
|
.Input(FakeInput(input_type))
|
||||||
|
.Attr("fill", fill)
|
||||||
|
.Attr("precision", precision)
|
||||||
|
.Attr("scientific", scientific)
|
||||||
|
.Attr("shortest", shortest)
|
||||||
|
.Attr("width", width)
|
||||||
|
.Finalize(node_def()));
|
||||||
|
return InitOp();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, Int8) {
|
||||||
|
TF_ASSERT_OK(Init(DT_INT8));
|
||||||
|
|
||||||
|
AddInputFromArray<int8>(TensorShape({3}), {-42, 0, 42});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
|
||||||
|
test::FillValues<tstring>(&expected, {"-42", "0", "42"});
|
||||||
|
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, Int64) {
|
||||||
|
TF_ASSERT_OK(Init(DT_INT64));
|
||||||
|
|
||||||
|
AddInputFromArray<int64>(TensorShape({3}), {-42, 0, 42});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
|
||||||
|
test::FillValues<tstring>(&expected, {"-42", "0", "42"});
|
||||||
|
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, FloatDefault) {
|
||||||
|
TF_ASSERT_OK(Init(DT_FLOAT));
|
||||||
|
|
||||||
|
AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
|
||||||
|
test::FillValues<tstring>(
|
||||||
|
&expected, {"-42.000000", "0.000000", "3.141590", "42.000000"});
|
||||||
|
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, FloatScientific) {
|
||||||
|
TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/-1,
|
||||||
|
/*scientific=*/true));
|
||||||
|
|
||||||
|
AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
|
||||||
|
test::FillValues<tstring>(&expected, {"-4.200000e+01", "0.000000e+00",
|
||||||
|
"3.141590e+00", "4.200000e+01"});
|
||||||
|
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, FloatShortest) {
|
||||||
|
TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/-1,
|
||||||
|
/*scientific=*/false, /*shortest=*/true));
|
||||||
|
|
||||||
|
AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
|
||||||
|
test::FillValues<tstring>(&expected, {"-42", "0", "3.14159", "42"});
|
||||||
|
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, FloatPrecisionOnly) {
|
||||||
|
TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/2));
|
||||||
|
|
||||||
|
AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
|
||||||
|
test::FillValues<tstring>(&expected, {"-42.00", "0.00", "3.14", "42.00"});
|
||||||
|
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, FloatWidthOnly) {
|
||||||
|
TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/5));
|
||||||
|
|
||||||
|
AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
|
||||||
|
test::FillValues<tstring>(
|
||||||
|
&expected, {"-42.000000", "0.000000", "3.141590", "42.000000"});
|
||||||
|
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, Float_5_2_Format) {
|
||||||
|
TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/5, /*precision=*/2));
|
||||||
|
|
||||||
|
AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
|
||||||
|
test::FillValues<tstring>(&expected, {"-42.00", " 0.00", " 3.14", "42.00"});
|
||||||
|
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, Complex) {
|
||||||
|
TF_ASSERT_OK(Init(DT_COMPLEX64, /*fill=*/"", /*width=*/5, /*precision=*/2));
|
||||||
|
|
||||||
|
AddInputFromArray<complex64>(TensorShape({3}), {{-4, 2}, {0}, {3.14159, -1}});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
|
||||||
|
test::FillValues<tstring>(
|
||||||
|
&expected, {"(-4.00, 2.00)", "( 0.00, 0.00)", "( 3.14,-1.00)"});
|
||||||
|
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, Bool) {
|
||||||
|
TF_ASSERT_OK(Init(DT_BOOL));
|
||||||
|
|
||||||
|
AddInputFromArray<bool>(TensorShape({2}), {true, false});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor expected(allocator(), DT_STRING, TensorShape({2}));
|
||||||
|
test::FillValues<tstring>(&expected, {"true", "false"});
|
||||||
|
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, String) {
|
||||||
|
Status s = Init(DT_STRING);
|
||||||
|
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
|
||||||
|
ASSERT_TRUE(absl::StrContains(
|
||||||
|
s.error_message(),
|
||||||
|
"Value for attr 'T' of string is not in the list of allowed values"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, OnlyOneOfScientificAndShortest) {
|
||||||
|
Status s = Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/-1,
|
||||||
|
/*scientific=*/true, /*shortest=*/true);
|
||||||
|
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
|
||||||
|
ASSERT_TRUE(
|
||||||
|
absl::StrContains(s.error_message(),
|
||||||
|
"Cannot select both scientific and shortest notation"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, NoShortestForNonFloat) {
|
||||||
|
Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/-1,
|
||||||
|
/*scientific=*/false, /*shortest=*/true);
|
||||||
|
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
|
||||||
|
ASSERT_TRUE(absl::StrContains(
|
||||||
|
s.error_message(),
|
||||||
|
"scientific and shortest format not supported for datatype"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, NoScientificForNonFloat) {
|
||||||
|
Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/-1,
|
||||||
|
/*scientific=*/true);
|
||||||
|
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
|
||||||
|
ASSERT_TRUE(absl::StrContains(
|
||||||
|
s.error_message(),
|
||||||
|
"scientific and shortest format not supported for datatype"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, NoPrecisionForNonFloat) {
|
||||||
|
Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/5);
|
||||||
|
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
|
||||||
|
ASSERT_TRUE(absl::StrContains(s.error_message(),
|
||||||
|
"precision not supported for datatype"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, LongFill) {
|
||||||
|
Status s = Init(DT_INT32, /*fill=*/"asdf");
|
||||||
|
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
|
||||||
|
ASSERT_TRUE(absl::StrContains(s.error_message(),
|
||||||
|
"Fill string must be one or fewer characters"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, FillWithZero) {
|
||||||
|
TF_ASSERT_OK(Init(DT_INT64, /*fill=*/"0", /*width=*/4));
|
||||||
|
|
||||||
|
AddInputFromArray<int64>(TensorShape({3}), {-42, 0, 42});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
|
||||||
|
test::FillValues<tstring>(&expected, {"-042", "0000", "0042"});
|
||||||
|
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, FillWithSpace) {
|
||||||
|
TF_ASSERT_OK(Init(DT_INT64, /*fill=*/" ", /*width=*/4));
|
||||||
|
|
||||||
|
AddInputFromArray<int64>(TensorShape({3}), {-42, 0, 42});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
|
||||||
|
test::FillValues<tstring>(&expected, {" -42", " 0", " 42"});
|
||||||
|
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, FillWithChar1) {
|
||||||
|
TF_ASSERT_OK(Init(DT_INT64, /*fill=*/"-", /*width=*/4));
|
||||||
|
|
||||||
|
AddInputFromArray<int64>(TensorShape({3}), {-42, 0, 42});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
|
||||||
|
test::FillValues<tstring>(&expected, {"-42 ", "0 ", "42 "});
|
||||||
|
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, FillWithChar3) {
|
||||||
|
Status s = Init(DT_INT32, /*fill=*/"s");
|
||||||
|
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
|
||||||
|
ASSERT_TRUE(
|
||||||
|
absl::StrContains(s.error_message(), "Fill argument not supported"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AsStringGraphTest, FillWithChar4) {
|
||||||
|
Status s = Init(DT_INT32, /*fill=*/"n");
|
||||||
|
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
|
||||||
|
ASSERT_TRUE(
|
||||||
|
absl::StrContains(s.error_message(), "Fill argument not supported"));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end namespace
|
||||||
|
} // end namespace tensorflow
|
@ -193,7 +193,8 @@ struct LaunchBatchBandedTriangularSolve {
|
|||||||
|
|
||||||
Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
|
Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
|
||||||
cost_per_unit,
|
cost_per_unit,
|
||||||
[&in_x, &in_y, adjoint, lower, &bcast, out](int start, int limit) {
|
[&in_x, &in_y, adjoint, lower, &bcast, out](int64 start,
|
||||||
|
int64 limit) {
|
||||||
SequentialBandedTriangularSolveKernel<Scalar>::Run(
|
SequentialBandedTriangularSolveKernel<Scalar>::Run(
|
||||||
in_x, in_y, lower, adjoint, bcast, out, start, limit);
|
in_x, in_y, lower, adjoint, bcast, out, start, limit);
|
||||||
});
|
});
|
||||||
|
@ -121,7 +121,7 @@ class BoostedTreesTrainingPredictOp : public OpKernel {
|
|||||||
auto do_work = [&resource, &bucketized_features, &cached_tree_ids,
|
auto do_work = [&resource, &bucketized_features, &cached_tree_ids,
|
||||||
&cached_node_ids, &output_partial_logits,
|
&cached_node_ids, &output_partial_logits,
|
||||||
&output_node_ids, latest_tree,
|
&output_node_ids, latest_tree,
|
||||||
this](int32 start, int32 end) {
|
this](int64 start, int64 end) {
|
||||||
for (int32 i = start; i < end; ++i) {
|
for (int32 i = start; i < end; ++i) {
|
||||||
int32 tree_id = cached_tree_ids(i);
|
int32 tree_id = cached_tree_ids(i);
|
||||||
int32 node_id = cached_node_ids(i);
|
int32 node_id = cached_node_ids(i);
|
||||||
@ -237,7 +237,7 @@ class BoostedTreesPredictOp : public OpKernel {
|
|||||||
|
|
||||||
const int32 last_tree = resource->num_trees() - 1;
|
const int32 last_tree = resource->num_trees() - 1;
|
||||||
auto do_work = [&resource, &bucketized_features, &output_logits, last_tree,
|
auto do_work = [&resource, &bucketized_features, &output_logits, last_tree,
|
||||||
this](int32 start, int32 end) {
|
this](int64 start, int64 end) {
|
||||||
for (int32 i = start; i < end; ++i) {
|
for (int32 i = start; i < end; ++i) {
|
||||||
std::vector<float> tree_logits(logits_dimension_, 0.0);
|
std::vector<float> tree_logits(logits_dimension_, 0.0);
|
||||||
int32 tree_id = 0;
|
int32 tree_id = 0;
|
||||||
@ -340,7 +340,7 @@ class BoostedTreesExampleDebugOutputsOp : public OpKernel {
|
|||||||
// path. Note: feature_ids has one less value than logits_path because the
|
// path. Note: feature_ids has one less value than logits_path because the
|
||||||
// first value of each logit path will be the bias.
|
// first value of each logit path will be the bias.
|
||||||
auto do_work = [&resource, &bucketized_features, &output_debug_info,
|
auto do_work = [&resource, &bucketized_features, &output_debug_info,
|
||||||
last_tree](int32 start, int32 end) {
|
last_tree](int64 start, int64 end) {
|
||||||
for (int32 i = start; i < end; ++i) {
|
for (int32 i = start; i < end; ++i) {
|
||||||
// Proto to store debug outputs, per example.
|
// Proto to store debug outputs, per example.
|
||||||
boosted_trees::DebugOutput example_debug_info;
|
boosted_trees::DebugOutput example_debug_info;
|
||||||
|
@ -178,10 +178,30 @@ class SparseCount : public OpKernel {
|
|||||||
const Tensor& weights = context->input(3);
|
const Tensor& weights = context->input(3);
|
||||||
bool use_weights = weights.NumElements() > 0;
|
bool use_weights = weights.NumElements() > 0;
|
||||||
|
|
||||||
|
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices.shape()),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Input indices must be a 2-dimensional tensor. Got: ",
|
||||||
|
indices.shape().DebugString()));
|
||||||
|
|
||||||
|
if (use_weights) {
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, weights.shape() == values.shape(),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Weights and values must have the same shape. Weight shape: ",
|
||||||
|
weights.shape().DebugString(),
|
||||||
|
"; values shape: ", values.shape().DebugString()));
|
||||||
|
}
|
||||||
|
|
||||||
bool is_1d = shape.NumElements() == 1;
|
bool is_1d = shape.NumElements() == 1;
|
||||||
int num_batches = is_1d ? 1 : shape.flat<int64>()(0);
|
int num_batches = is_1d ? 1 : shape.flat<int64>()(0);
|
||||||
int num_values = values.NumElements();
|
int num_values = values.NumElements();
|
||||||
|
|
||||||
|
OP_REQUIRES(context, num_values == indices.shape().dim_size(0),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Number of values must match first dimension of indices.",
|
||||||
|
"Got ", num_values,
|
||||||
|
" values, indices shape: ", indices.shape().DebugString()));
|
||||||
|
|
||||||
const auto indices_values = indices.matrix<int64>();
|
const auto indices_values = indices.matrix<int64>();
|
||||||
const auto values_values = values.flat<T>();
|
const auto values_values = values.flat<T>();
|
||||||
const auto weight_values = weights.flat<W>();
|
const auto weight_values = weights.flat<W>();
|
||||||
@ -235,12 +255,33 @@ class RaggedCount : public OpKernel {
|
|||||||
bool use_weights = weights.NumElements() > 0;
|
bool use_weights = weights.NumElements() > 0;
|
||||||
bool is_1d = false;
|
bool is_1d = false;
|
||||||
|
|
||||||
|
if (use_weights) {
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, weights.shape() == values.shape(),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Weights and values must have the same shape. Weight shape: ",
|
||||||
|
weights.shape().DebugString(),
|
||||||
|
"; values shape: ", values.shape().DebugString()));
|
||||||
|
}
|
||||||
|
|
||||||
const auto splits_values = splits.flat<int64>();
|
const auto splits_values = splits.flat<int64>();
|
||||||
const auto values_values = values.flat<T>();
|
const auto values_values = values.flat<T>();
|
||||||
const auto weight_values = weights.flat<W>();
|
const auto weight_values = weights.flat<W>();
|
||||||
int num_batches = splits.NumElements() - 1;
|
int num_batches = splits.NumElements() - 1;
|
||||||
int num_values = values.NumElements();
|
int num_values = values.NumElements();
|
||||||
|
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, num_batches > 0,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Must provide at least 2 elements for the splits argument"));
|
||||||
|
OP_REQUIRES(context, splits_values(0) == 0,
|
||||||
|
errors::InvalidArgument("Splits must start with 0, not with ",
|
||||||
|
splits_values(0)));
|
||||||
|
OP_REQUIRES(context, splits_values(num_batches) == num_values,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Splits must end with the number of values, got ",
|
||||||
|
splits_values(num_batches), " instead of ", num_values));
|
||||||
|
|
||||||
auto per_batch_counts = BatchedMap<W>(num_batches);
|
auto per_batch_counts = BatchedMap<W>(num_batches);
|
||||||
T max_value = 0;
|
T max_value = 0;
|
||||||
int batch_idx = 0;
|
int batch_idx = 0;
|
||||||
|
@ -223,7 +223,7 @@ struct CropAndResize<CPUDevice, T> {
|
|||||||
const int depth = crops.dimension(3);
|
const int depth = crops.dimension(3);
|
||||||
|
|
||||||
// Sharding across boxes.
|
// Sharding across boxes.
|
||||||
auto CropAndResizePerBox = [&](int start_box, int limit_box) {
|
auto CropAndResizePerBox = [&](int64 start_box, int64 limit_box) {
|
||||||
for (int b = start_box; b < limit_box; ++b) {
|
for (int b = start_box; b < limit_box; ++b) {
|
||||||
const float y1 = boxes(b, 0);
|
const float y1 = boxes(b, 0);
|
||||||
const float x1 = boxes(b, 1);
|
const float x1 = boxes(b, 1);
|
||||||
@ -449,7 +449,7 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
|
|||||||
|
|
||||||
grads_image.setZero();
|
grads_image.setZero();
|
||||||
|
|
||||||
auto CropAndResizeBackImgPerBox = [&](int start_box, int limit_box) {
|
auto CropAndResizeBackImgPerBox = [&](int64 start_box, int64 limit_box) {
|
||||||
for (int b = start_box; b < limit_box; ++b) {
|
for (int b = start_box; b < limit_box; ++b) {
|
||||||
const float y1 = boxes(b, 0);
|
const float y1 = boxes(b, 0);
|
||||||
const float x1 = boxes(b, 1);
|
const float x1 = boxes(b, 1);
|
||||||
|
@ -95,7 +95,8 @@ struct NthElementFunctor<CPUDevice, T> {
|
|||||||
const int last_dim = input_tensor.dim_size(input_tensor.dims() - 1);
|
const int last_dim = input_tensor.dim_size(input_tensor.dims() - 1);
|
||||||
|
|
||||||
// Allocate each row to different shard.
|
// Allocate each row to different shard.
|
||||||
auto SubNthElement = [&, input, output, last_dim, n](int start, int limit) {
|
auto SubNthElement = [&, input, output, last_dim, n](int64 start,
|
||||||
|
int64 limit) {
|
||||||
// std::nth_element would rearrange the array, so we need a new buffer.
|
// std::nth_element would rearrange the array, so we need a new buffer.
|
||||||
std::vector<T> buf(last_dim);
|
std::vector<T> buf(last_dim);
|
||||||
|
|
||||||
|
@ -70,8 +70,8 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
|||||||
|
|
||||||
auto do_work = [samples_per_batch, num_elements, &ctx, &means, &stddevs,
|
auto do_work = [samples_per_batch, num_elements, &ctx, &means, &stddevs,
|
||||||
&minvals, &maxvals, &gen, &output,
|
&minvals, &maxvals, &gen, &output,
|
||||||
kStdDevsInsideBoundsToUseRandnSampler](int start_batch,
|
kStdDevsInsideBoundsToUseRandnSampler](int64 start_batch,
|
||||||
int limit_batch) {
|
int64 limit_batch) {
|
||||||
// Capturing "gen" by-value would only make a copy for the _shared_
|
// Capturing "gen" by-value would only make a copy for the _shared_
|
||||||
// lambda. Since we want to let each worker have its own copy, we pass
|
// lambda. Since we want to let each worker have its own copy, we pass
|
||||||
// "gen" by reference and explicitly do a copy assignment here.
|
// "gen" by reference and explicitly do a copy assignment here.
|
||||||
@ -333,8 +333,8 @@ struct TruncatedNormalFunctorV2<CPUDevice, T> {
|
|||||||
|
|
||||||
auto do_work = [num_batches, samples_per_batch, &ctx, &bcast, &means,
|
auto do_work = [num_batches, samples_per_batch, &ctx, &bcast, &means,
|
||||||
&stddevs, &minvals, &maxvals, &gen, &output,
|
&stddevs, &minvals, &maxvals, &gen, &output,
|
||||||
kStdDevsInsideBoundsToUseRandnSampler](int start_output,
|
kStdDevsInsideBoundsToUseRandnSampler](int64 start_output,
|
||||||
int limit_output) {
|
int64 limit_output) {
|
||||||
// Capturing "gen" by-value would only make a copy for the _shared_
|
// Capturing "gen" by-value would only make a copy for the _shared_
|
||||||
// lambda. Since we want to let each worker have its own copy, we pass
|
// lambda. Since we want to let each worker have its own copy, we pass
|
||||||
// "gen" by reference and explicitly do a copy assignment here.
|
// "gen" by reference and explicitly do a copy assignment here.
|
||||||
|
@ -182,7 +182,7 @@ struct RandomBinomialFunctor<CPUDevice, T, U> {
|
|||||||
// the sample shape and [H1, ... Hm] for the batch shape of the samples.
|
// the sample shape and [H1, ... Hm] for the batch shape of the samples.
|
||||||
// We have B1 * ... * Bk samples per batch member we need.
|
// We have B1 * ... * Bk samples per batch member we need.
|
||||||
auto DoWork = [num_batches, samples_per_batch, &bcast, &counts, &probs,
|
auto DoWork = [num_batches, samples_per_batch, &bcast, &counts, &probs,
|
||||||
&gen, &output](int start_output, int limit_output) {
|
&gen, &output](int64 start_output, int64 limit_output) {
|
||||||
// Vectorized intermediate calculations for uniform rejection sampling.
|
// Vectorized intermediate calculations for uniform rejection sampling.
|
||||||
// We always generate at most 4 samples.
|
// We always generate at most 4 samples.
|
||||||
Eigen::array<T, 4> z;
|
Eigen::array<T, 4> z;
|
||||||
|
@ -205,7 +205,7 @@ class RandomGammaOp : public OpKernel {
|
|||||||
// avoid a couple flops which can be done on a per-alpha basis.
|
// avoid a couple flops which can be done on a per-alpha basis.
|
||||||
|
|
||||||
auto DoWork = [samples_per_alpha, num_alphas, &rng, samples_flat,
|
auto DoWork = [samples_per_alpha, num_alphas, &rng, samples_flat,
|
||||||
alpha_flat](int start_output, int limit_output) {
|
alpha_flat](int64 start_output, int64 limit_output) {
|
||||||
using Eigen::numext::exp;
|
using Eigen::numext::exp;
|
||||||
using Eigen::numext::log;
|
using Eigen::numext::log;
|
||||||
using Eigen::numext::log1p;
|
using Eigen::numext::log1p;
|
||||||
|
@ -97,7 +97,7 @@ struct PoissonFunctor<CPUDevice, T, U> {
|
|||||||
typedef random::UniformDistribution<random::PhiloxRandom, CT> Uniform;
|
typedef random::UniformDistribution<random::PhiloxRandom, CT> Uniform;
|
||||||
|
|
||||||
auto DoWork = [num_samples, num_rate, &rng, samples_flat, rate_flat](
|
auto DoWork = [num_samples, num_rate, &rng, samples_flat, rate_flat](
|
||||||
int start_output, int limit_output) {
|
int64 start_output, int64 limit_output) {
|
||||||
// Capturing "rng" by value would only make a copy for the _shared_
|
// Capturing "rng" by value would only make a copy for the _shared_
|
||||||
// lambda. Since we want to let each worker have its own copy, we pass
|
// lambda. Since we want to let each worker have its own copy, we pass
|
||||||
// "rng" by reference and explicitly do a copy assignment.
|
// "rng" by reference and explicitly do a copy assignment.
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
// See docs in ../ops/data_flow_ops.cc.
|
// See docs in ../ops/data_flow_ops.cc.
|
||||||
|
|
||||||
#include <limits.h>
|
#include <limits.h>
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
@ -27,6 +28,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
@ -42,7 +44,11 @@ class GetSessionHandleOp : public OpKernel {
|
|||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
const Tensor& val = ctx->input(0);
|
const Tensor& val = ctx->input(0);
|
||||||
int64 id = ctx->session_state()->GetNewId();
|
auto session_state = ctx->session_state();
|
||||||
|
OP_REQUIRES(ctx, session_state != nullptr,
|
||||||
|
errors::FailedPrecondition(
|
||||||
|
"GetSessionHandle called on null session state"));
|
||||||
|
int64 id = session_state->GetNewId();
|
||||||
TensorStore::TensorAndKey tk{val, id, requested_device()};
|
TensorStore::TensorAndKey tk{val, id, requested_device()};
|
||||||
OP_REQUIRES_OK(ctx, ctx->tensor_store()->AddTensor(name(), tk));
|
OP_REQUIRES_OK(ctx, ctx->tensor_store()->AddTensor(name(), tk));
|
||||||
|
|
||||||
|
@ -232,6 +232,9 @@ class SparseFillEmptyRowsGradOp : public OpKernel {
|
|||||||
context, TensorShapeUtils::IsVector(reverse_index_map_t->shape()),
|
context, TensorShapeUtils::IsVector(reverse_index_map_t->shape()),
|
||||||
errors::InvalidArgument("reverse_index_map must be a vector, saw: ",
|
errors::InvalidArgument("reverse_index_map must be a vector, saw: ",
|
||||||
reverse_index_map_t->shape().DebugString()));
|
reverse_index_map_t->shape().DebugString()));
|
||||||
|
OP_REQUIRES(context, TensorShapeUtils::IsVector(grad_values_t->shape()),
|
||||||
|
errors::InvalidArgument("grad_values must be a vector, saw: ",
|
||||||
|
grad_values_t->shape().DebugString()));
|
||||||
|
|
||||||
const auto reverse_index_map = reverse_index_map_t->vec<int64>();
|
const auto reverse_index_map = reverse_index_map_t->vec<int64>();
|
||||||
const auto grad_values = grad_values_t->vec<T>();
|
const auto grad_values = grad_values_t->vec<T>();
|
||||||
@ -260,8 +263,13 @@ class SparseFillEmptyRowsGradOp : public OpKernel {
|
|||||||
// Locate the index of the output of the forward prop associated
|
// Locate the index of the output of the forward prop associated
|
||||||
// with this location in the input of the forward prop. Copy
|
// with this location in the input of the forward prop. Copy
|
||||||
// the gradient into it. Mark it as visited.
|
// the gradient into it. Mark it as visited.
|
||||||
d_values(i) = grad_values(reverse_index_map(i));
|
int64 reverse_index = reverse_index_map(i);
|
||||||
visited(reverse_index_map(i)) = true;
|
OP_REQUIRES(
|
||||||
|
context, 0 <= reverse_index && reverse_index < N_full,
|
||||||
|
errors::InvalidArgument("Elements in reverse index must be in [0, ",
|
||||||
|
N_full, ") but got ", reverse_index));
|
||||||
|
d_values(i) = grad_values(reverse_index);
|
||||||
|
visited(reverse_index) = true;
|
||||||
}
|
}
|
||||||
for (int j = 0; j < N_full; ++j) {
|
for (int j = 0; j < N_full; ++j) {
|
||||||
// The default value gradient gets the accumulated remainder of
|
// The default value gradient gets the accumulated remainder of
|
||||||
|
@ -252,7 +252,7 @@ class StatelessRandomGammaOp : public StatelessRandomOpBase {
|
|||||||
// avoid a couple flops which can be done on a per-alpha basis.
|
// avoid a couple flops which can be done on a per-alpha basis.
|
||||||
|
|
||||||
auto DoWork = [samples_per_alpha, num_alphas, &random, samples_flat,
|
auto DoWork = [samples_per_alpha, num_alphas, &random, samples_flat,
|
||||||
alpha_flat](int start_output, int limit_output) {
|
alpha_flat](int64 start_output, int64 limit_output) {
|
||||||
// Capturing "random" by-value would only make a copy for the _shared_
|
// Capturing "random" by-value would only make a copy for the _shared_
|
||||||
// lambda. Since we want to let each worker have its own copy, we pass
|
// lambda. Since we want to let each worker have its own copy, we pass
|
||||||
// "random" by reference and explicitly do a copy assignment.
|
// "random" by reference and explicitly do a copy assignment.
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include "absl/strings/ascii.h"
|
#include "absl/strings/ascii.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace text {
|
namespace text {
|
||||||
@ -60,6 +61,18 @@ class StringNGramsOp : public tensorflow::OpKernel {
|
|||||||
OP_REQUIRES_OK(context, context->input("data_splits", &splits));
|
OP_REQUIRES_OK(context, context->input("data_splits", &splits));
|
||||||
const auto& splits_vec = splits->flat<SPLITS_TYPE>();
|
const auto& splits_vec = splits->flat<SPLITS_TYPE>();
|
||||||
|
|
||||||
|
// Validate that the splits are valid indices into data
|
||||||
|
const int input_data_size = data->flat<tstring>().size();
|
||||||
|
const int splits_vec_size = splits_vec.size();
|
||||||
|
for (int i = 0; i < splits_vec_size; ++i) {
|
||||||
|
bool valid_splits = splits_vec(i) >= 0;
|
||||||
|
valid_splits = valid_splits && (splits_vec(i) <= input_data_size);
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, valid_splits,
|
||||||
|
errors::InvalidArgument("Invalid split value ", splits_vec(i),
|
||||||
|
", must be in [0,", input_data_size, "]"));
|
||||||
|
}
|
||||||
|
|
||||||
int num_batch_items = splits_vec.size() - 1;
|
int num_batch_items = splits_vec.size() - 1;
|
||||||
tensorflow::Tensor* ngrams_splits;
|
tensorflow::Tensor* ngrams_splits;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
|
@ -136,7 +136,7 @@ struct TopKFunctor<CPUDevice, T> {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto SortIndices = [&](int start_batch, int limit_batch) {
|
auto SortIndices = [&](int64 start_batch, int64 limit_batch) {
|
||||||
for (int32 b = start_batch; b < limit_batch; ++b) {
|
for (int32 b = start_batch; b < limit_batch; ++b) {
|
||||||
const T* input_data = &input(b, 0);
|
const T* input_data = &input(b, 0);
|
||||||
const auto stable_comp = [input_data](const int32 a, const int32 b) {
|
const auto stable_comp = [input_data](const int32 a, const int32 b) {
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
#include "tensorflow/lite/arena_planner.h"
|
#include "tensorflow/lite/arena_planner.h"
|
||||||
|
#include "tensorflow/lite/builtin_ops.h"
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/context_util.h"
|
#include "tensorflow/lite/context_util.h"
|
||||||
#include "tensorflow/lite/core/api/tensor_utils.h"
|
#include "tensorflow/lite/core/api/tensor_utils.h"
|
||||||
@ -567,6 +568,33 @@ TfLiteStatus Subgraph::CheckTensorIndices(const char* label, const int* indices,
|
|||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We have two arrays and we need to check that elements from one array don't
|
||||||
|
// show up in the other. We could sort both arrays and then iterate with two
|
||||||
|
// pointers from start to finish always increasing the smaller one but since
|
||||||
|
// these arrays are usually short (<25 elements for inputs, usually <3 for
|
||||||
|
// outputs), this might be slower than the naive approach (if arrays have size n
|
||||||
|
// and m, with n >> m ~ O(1), first approach is O(nlogn) whereas the other is
|
||||||
|
// O(n)). Plus, sorting the input and output arrays might not be something we
|
||||||
|
// want as it destroys ordering of elements.
|
||||||
|
//
|
||||||
|
// If it turns out that this is an issue, we can switch to the other algorithm.
|
||||||
|
TfLiteStatus Subgraph::CheckInputAndOutputForOverlap(const int* input_indices,
|
||||||
|
int num_inputs,
|
||||||
|
const int* output_indices,
|
||||||
|
int num_outputs) {
|
||||||
|
for (int i = 0; i < num_inputs; i++) {
|
||||||
|
for (int j = 0; j < num_outputs; j++) {
|
||||||
|
if (input_indices[i] == output_indices[j]) {
|
||||||
|
ReportError("Tensor %d is both input %d and output %d\n",
|
||||||
|
input_indices[i], i, j);
|
||||||
|
consistent_ = false;
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Multiply two sizes and return true if overflow occurred;
|
// Multiply two sizes and return true if overflow occurred;
|
||||||
// This is based off tensorflow/overflow.h but is simpler as we already
|
// This is based off tensorflow/overflow.h but is simpler as we already
|
||||||
@ -688,6 +716,16 @@ TfLiteStatus Subgraph::AddNodeWithParameters(
|
|||||||
&context_,
|
&context_,
|
||||||
CheckTensorIndices("node outputs", outputs.data(), outputs.size()));
|
CheckTensorIndices("node outputs", outputs.data(), outputs.size()));
|
||||||
|
|
||||||
|
// For builtin ops, inputs and outputs must not overlap. Custom ops must do
|
||||||
|
// this check by themselves if they don't support overlapping tensors. This
|
||||||
|
// distinction is to allow custom ops to just forward a tensor, reusing it as
|
||||||
|
// both input and output.
|
||||||
|
if (builtin_data != nullptr) {
|
||||||
|
TF_LITE_ENSURE_OK(&context_, CheckInputAndOutputForOverlap(
|
||||||
|
inputs.data(), inputs.size(),
|
||||||
|
outputs.data(), outputs.size()));
|
||||||
|
}
|
||||||
|
|
||||||
int new_node_index = nodes_and_registration_.size();
|
int new_node_index = nodes_and_registration_.size();
|
||||||
if (node_index) *node_index = new_node_index;
|
if (node_index) *node_index = new_node_index;
|
||||||
nodes_and_registration_.resize(nodes_and_registration_.size() + 1);
|
nodes_and_registration_.resize(nodes_and_registration_.size() + 1);
|
||||||
@ -934,6 +972,19 @@ TfLiteStatus Subgraph::Invoke() {
|
|||||||
tensor->data_is_stale) {
|
tensor->data_is_stale) {
|
||||||
TF_LITE_ENSURE_STATUS(EnsureTensorDataIsReadable(tensor_index));
|
TF_LITE_ENSURE_STATUS(EnsureTensorDataIsReadable(tensor_index));
|
||||||
}
|
}
|
||||||
|
if (tensor->data.raw == nullptr && tensor->bytes > 0) {
|
||||||
|
if (registration.builtin_code == kTfLiteBuiltinReshape && i == 1) {
|
||||||
|
// In general, having a tensor here with no buffer will be an error.
|
||||||
|
// However, for the reshape operator, the second input tensor is only
|
||||||
|
// used for the shape, not for the data. Thus, null buffer is ok.
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
// In all other cases, we need to return an error as otherwise we will
|
||||||
|
// trigger a null pointer dereference (likely).
|
||||||
|
ReportError("Input tensor %d lacks data", tensor_index);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (check_cancelled_func_ != nullptr &&
|
if (check_cancelled_func_ != nullptr &&
|
||||||
|
@ -433,6 +433,15 @@ class Subgraph {
|
|||||||
TfLiteStatus CheckTensorIndices(const char* label, const int* indices,
|
TfLiteStatus CheckTensorIndices(const char* label, const int* indices,
|
||||||
int length);
|
int length);
|
||||||
|
|
||||||
|
// Check that the input indices and the output indices don't overlap.
|
||||||
|
// This is needed because same tensor must not be used both as input and
|
||||||
|
// output for an operator.
|
||||||
|
// NOTE: this changes consistent_ to be false if indices are out of bounds.
|
||||||
|
TfLiteStatus CheckInputAndOutputForOverlap(const int* input_indices,
|
||||||
|
int num_inputs,
|
||||||
|
const int* output_indices,
|
||||||
|
int num_outputs);
|
||||||
|
|
||||||
// Compute the number of bytes required to represent a tensor with dimensions
|
// Compute the number of bytes required to represent a tensor with dimensions
|
||||||
// specified by the array dims (of length dims_size). Returns the status code
|
// specified by the array dims (of length dims_size). Returns the status code
|
||||||
// and bytes.
|
// and bytes.
|
||||||
|
@ -609,7 +609,12 @@ TfLiteStatus InterpreterBuilder::operator()(
|
|||||||
auto* buffers = model_->buffers();
|
auto* buffers = model_->buffers();
|
||||||
|
|
||||||
if (subgraphs->size() == 0) {
|
if (subgraphs->size() == 0) {
|
||||||
error_reporter_->Report("No subgraph in the model.\n");
|
TF_LITE_REPORT_ERROR(error_reporter_, "No subgraph in the model.\n");
|
||||||
|
return cleanup_and_error();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!buffers) {
|
||||||
|
TF_LITE_REPORT_ERROR(error_reporter_, "No buffers in the model.\n");
|
||||||
return cleanup_and_error();
|
return cleanup_and_error();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -630,9 +635,9 @@ TfLiteStatus InterpreterBuilder::operator()(
|
|||||||
(*interpreter)->subgraph(subgraph_index);
|
(*interpreter)->subgraph(subgraph_index);
|
||||||
auto operators = subgraph->operators();
|
auto operators = subgraph->operators();
|
||||||
auto tensors = subgraph->tensors();
|
auto tensors = subgraph->tensors();
|
||||||
if (!operators || !tensors || !buffers) {
|
if (!operators || !tensors) {
|
||||||
error_reporter_->Report(
|
TF_LITE_REPORT_ERROR(error_reporter_,
|
||||||
"Did not get operators, tensors, or buffers in subgraph %d.\n",
|
"Did not get operators or tensors in subgraph %d.\n",
|
||||||
subgraph_index);
|
subgraph_index);
|
||||||
return cleanup_and_error();
|
return cleanup_and_error();
|
||||||
}
|
}
|
||||||
|
@ -70,6 +70,9 @@ inline bool ResolveAxis(const int num_dims, const int* axis,
|
|||||||
// eg: For num_dims=3, [0, 1, 2] is the same as [-3, -2, -1] */
|
// eg: For num_dims=3, [0, 1, 2] is the same as [-3, -2, -1] */
|
||||||
int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx];
|
int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx];
|
||||||
TFLITE_DCHECK(current >= 0 && current < num_dims);
|
TFLITE_DCHECK(current >= 0 && current < num_dims);
|
||||||
|
if (current < 0 || current >= num_dims) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
bool is_dup = false;
|
bool is_dup = false;
|
||||||
for (int j = 0; j < *out_num_axis; ++j) {
|
for (int j = 0; j < *out_num_axis; ++j) {
|
||||||
if (out_axis[j] == current) {
|
if (out_axis[j] == current) {
|
||||||
|
@ -432,7 +432,7 @@ int MatchingArraySize(const ArrayType1& array1, int index1,
|
|||||||
inline int MatchingDim(const RuntimeShape& shape1, int index1,
|
inline int MatchingDim(const RuntimeShape& shape1, int index1,
|
||||||
const RuntimeShape& shape2, int index2) {
|
const RuntimeShape& shape2, int index2) {
|
||||||
TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
|
TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
|
||||||
return shape1.Dims(index1);
|
return std::min(shape1.Dims(index1), shape2.Dims(index2));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
|
@ -30,27 +30,48 @@ inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
|
|||||||
}
|
}
|
||||||
inline const TfLiteTensor* GetInput(const TfLiteContext* context,
|
inline const TfLiteTensor* GetInput(const TfLiteContext* context,
|
||||||
const TfLiteNode* node, int index) {
|
const TfLiteNode* node, int index) {
|
||||||
return &context->tensors[node->inputs->data[index]];
|
const int tensor_index = node->inputs->data[index];
|
||||||
|
if (tensor_index < 0) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return &context->tensors[tensor_index];
|
||||||
}
|
}
|
||||||
// Note: You must check if result is not null:
|
// Note: You must check if result is not null:
|
||||||
// TfLiteTensor* my_tensor = GetVariableInput(context, node, kMyTensorIdx);
|
// TfLiteTensor* my_tensor = GetVariableInput(context, node, kMyTensorIdx);
|
||||||
// TF_LITE_ENSURE(context, my_tensor != nullptr);
|
// TF_LITE_ENSURE(context, my_tensor != nullptr);
|
||||||
inline TfLiteTensor* GetVariableInput(TfLiteContext* context,
|
inline TfLiteTensor* GetVariableInput(TfLiteContext* context,
|
||||||
const TfLiteNode* node, int index) {
|
const TfLiteNode* node, int index) {
|
||||||
TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];
|
const int tensor_index = node->inputs->data[index];
|
||||||
|
if (tensor_index < 0) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
TfLiteTensor* tensor = &context->tensors[tensor_index];
|
||||||
return (tensor->is_variable) ? tensor : nullptr;
|
return (tensor->is_variable) ? tensor : nullptr;
|
||||||
}
|
}
|
||||||
inline TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
|
inline TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
|
||||||
int index) {
|
int index) {
|
||||||
return &context->tensors[node->outputs->data[index]];
|
const int tensor_index = node->outputs->data[index];
|
||||||
|
if (tensor_index < 0) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return &context->tensors[tensor_index];
|
||||||
}
|
}
|
||||||
inline TfLiteTensor* GetTemporary(TfLiteContext* context,
|
inline TfLiteTensor* GetTemporary(TfLiteContext* context,
|
||||||
const TfLiteNode* node, int index) {
|
const TfLiteNode* node, int index) {
|
||||||
return &context->tensors[node->temporaries->data[index]];
|
const int tensor_index = node->temporaries->data[index];
|
||||||
|
if (tensor_index < 0) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return &context->tensors[tensor_index];
|
||||||
}
|
}
|
||||||
|
|
||||||
inline const TfLiteTensor* GetIntermediates(TfLiteContext* context,
|
inline const TfLiteTensor* GetIntermediates(TfLiteContext* context,
|
||||||
const TfLiteNode* node, int index) {
|
const TfLiteNode* node, int index) {
|
||||||
return &context->tensors[node->intermediates->data[index]];
|
const int tensor_index = node->intermediates->data[index];
|
||||||
|
if (tensor_index < 0) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return &context->tensors[tensor_index];
|
||||||
}
|
}
|
||||||
inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; }
|
inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; }
|
||||||
inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; }
|
inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; }
|
||||||
@ -73,12 +94,7 @@ inline int64_t NumElements(const TfLiteTensor* t) {
|
|||||||
inline const TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context,
|
inline const TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context,
|
||||||
const TfLiteNode* node,
|
const TfLiteNode* node,
|
||||||
int index) {
|
int index) {
|
||||||
const bool use_tensor = index < node->inputs->size &&
|
return GetInput(context, node, index);
|
||||||
node->inputs->data[index] != kTfLiteOptionalTensor;
|
|
||||||
if (use_tensor) {
|
|
||||||
return &context->tensors[node->inputs->data[index]];
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determines whether tensor is constant.
|
// Determines whether tensor is constant.
|
||||||
|
@ -34,11 +34,24 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
|
|||||||
const TfLiteTensor* data,
|
const TfLiteTensor* data,
|
||||||
const TfLiteTensor* segment_ids,
|
const TfLiteTensor* segment_ids,
|
||||||
TfLiteTensor* output) {
|
TfLiteTensor* output) {
|
||||||
int max_index = -1;
|
// Segment ids should be of same cardinality as first input dimension and they
|
||||||
|
// should be increasing by at most 1, from 0 (e.g., [0, 0, 1, 2, 3] is valid)
|
||||||
const int segment_id_size = segment_ids->dims->data[0];
|
const int segment_id_size = segment_ids->dims->data[0];
|
||||||
if (segment_id_size > 0) {
|
TF_LITE_ENSURE_EQ(context, segment_id_size, data->dims->data[0]);
|
||||||
max_index = segment_ids->data.i32[segment_id_size - 1];
|
int previous_segment_id = -1;
|
||||||
|
for (int i = 0; i < segment_id_size; i++) {
|
||||||
|
const int current_segment_id = GetTensorData<int32_t>(segment_ids)[i];
|
||||||
|
if (i == 0) {
|
||||||
|
TF_LITE_ENSURE_EQ(context, current_segment_id, 0);
|
||||||
|
} else {
|
||||||
|
int delta = current_segment_id - previous_segment_id;
|
||||||
|
TF_LITE_ENSURE(context, delta == 0 || delta == 1);
|
||||||
}
|
}
|
||||||
|
previous_segment_id = current_segment_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int max_index = previous_segment_id;
|
||||||
|
|
||||||
const int data_rank = NumDimensions(data);
|
const int data_rank = NumDimensions(data);
|
||||||
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(data));
|
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(data));
|
||||||
output_shape->data[0] = max_index + 1;
|
output_shape->data[0] = max_index + 1;
|
||||||
|
@ -110,5 +110,37 @@ TEST(SegmentSumOpModelTest, Float32Test_ThreeDimensions) {
|
|||||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 1}));
|
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(SegmentSumOpModelTest, TestFailIfSegmentsAreNotSorted) {
|
||||||
|
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3, 2}},
|
||||||
|
{TensorType_INT32, {3}});
|
||||||
|
model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
|
||||||
|
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 3, 1});
|
||||||
|
ASSERT_EQ(model.InvokeUnchecked(), kTfLiteError);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SegmentSumOpModelTest, TestFailIfSegmentsAreNotConsecutive) {
|
||||||
|
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3, 2}},
|
||||||
|
{TensorType_INT32, {3}});
|
||||||
|
model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
|
||||||
|
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 3, 5});
|
||||||
|
ASSERT_EQ(model.InvokeUnchecked(), kTfLiteError);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SegmentSumOpModelTest, TestFailIfSegmentsAreNegative) {
|
||||||
|
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3, 2}},
|
||||||
|
{TensorType_INT32, {3}});
|
||||||
|
model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
|
||||||
|
model.PopulateTensor<int32_t>(model.segment_ids(), {-1, 0, 1});
|
||||||
|
ASSERT_EQ(model.InvokeUnchecked(), kTfLiteError);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SegmentSumOpModelTest, TestFailIfSegmentsAreNotTheRightCardinality) {
|
||||||
|
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3, 2}},
|
||||||
|
{TensorType_INT32, {2}});
|
||||||
|
model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
|
||||||
|
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1});
|
||||||
|
ASSERT_EQ(model.InvokeUnchecked(), kTfLiteError);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -20,9 +20,11 @@ from __future__ import print_function
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
from tensorflow.python.dlpack import dlpack
|
from tensorflow.python.dlpack import dlpack
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
@ -95,6 +97,12 @@ class DLPackTest(parameterized.TestCase, test.TestCase):
|
|||||||
self.assertRaisesRegex(Exception, ".* is not supported by dlpack",
|
self.assertRaisesRegex(Exception, ".* is not supported by dlpack",
|
||||||
UnsupportedComplex64)
|
UnsupportedComplex64)
|
||||||
|
|
||||||
|
def testMustPassTensorArgumentToDLPack(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
errors.InvalidArgumentError,
|
||||||
|
"The argument to `to_dlpack` must be a TF tensor, not Python object"):
|
||||||
|
dlpack.to_dlpack([1])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
|
@ -4581,6 +4581,14 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
|
|||||||
result = control_flow_ops.merge([v_f, v_t])
|
result = control_flow_ops.merge([v_f, v_t])
|
||||||
self.evaluate(result)
|
self.evaluate(result)
|
||||||
|
|
||||||
|
def testSwitchEagerMode(self):
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
return
|
||||||
|
input_data = [1, 2, 3, 4]
|
||||||
|
vf, vt = control_flow_ops.switch(input_data, False)
|
||||||
|
self.assertAllEqual(vf, input_data)
|
||||||
|
self.assertAllEqual(vt, [])
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testQIntArgAndRet(self):
|
def testQIntArgAndRet(self):
|
||||||
|
|
||||||
|
@ -25,7 +25,9 @@ from tensorflow.python.eager import context
|
|||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import bincount_ops
|
from tensorflow.python.ops import bincount_ops
|
||||||
|
from tensorflow.python.ops import gen_count_ops
|
||||||
from tensorflow.python.ops import sparse_ops
|
from tensorflow.python.ops import sparse_ops
|
||||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
@ -834,5 +836,121 @@ class TestSparseCountFailureModes(test.TestCase):
|
|||||||
self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
|
self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
|
@test_util.disable_tfrt
|
||||||
|
class RawOpsTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def testSparseCountSparseOutputBadIndicesShape(self):
|
||||||
|
indices = [[[0], [0]], [[0], [1]], [[1], [0]], [[1], [2]]]
|
||||||
|
values = [1, 1, 1, 10]
|
||||||
|
weights = [1, 2, 4, 6]
|
||||||
|
dense_shape = [2, 3]
|
||||||
|
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||||
|
"Input indices must be a 2-dimensional tensor"):
|
||||||
|
self.evaluate(
|
||||||
|
gen_count_ops.SparseCountSparseOutput(
|
||||||
|
indices=indices,
|
||||||
|
values=values,
|
||||||
|
dense_shape=dense_shape,
|
||||||
|
weights=weights,
|
||||||
|
binary_output=False))
|
||||||
|
|
||||||
|
def testSparseCountSparseOutputBadWeightsShape(self):
|
||||||
|
indices = [[0, 0], [0, 1], [1, 0], [1, 2]]
|
||||||
|
values = [1, 1, 1, 10]
|
||||||
|
weights = [1, 2, 4]
|
||||||
|
dense_shape = [2, 3]
|
||||||
|
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||||
|
"Weights and values must have the same shape"):
|
||||||
|
self.evaluate(
|
||||||
|
gen_count_ops.SparseCountSparseOutput(
|
||||||
|
indices=indices,
|
||||||
|
values=values,
|
||||||
|
dense_shape=dense_shape,
|
||||||
|
weights=weights,
|
||||||
|
binary_output=False))
|
||||||
|
|
||||||
|
def testSparseCountSparseOutputBadNumberOfValues(self):
|
||||||
|
indices = [[0, 0], [0, 1], [1, 0]]
|
||||||
|
values = [1, 1, 1, 10]
|
||||||
|
weights = [1, 2, 4, 6]
|
||||||
|
dense_shape = [2, 3]
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
errors.InvalidArgumentError,
|
||||||
|
"Number of values must match first dimension of indices"):
|
||||||
|
self.evaluate(
|
||||||
|
gen_count_ops.SparseCountSparseOutput(
|
||||||
|
indices=indices,
|
||||||
|
values=values,
|
||||||
|
dense_shape=dense_shape,
|
||||||
|
weights=weights,
|
||||||
|
binary_output=False))
|
||||||
|
|
||||||
|
def testRaggedCountSparseOutput(self):
|
||||||
|
splits = [0, 4, 7]
|
||||||
|
values = [1, 1, 2, 1, 2, 10, 5]
|
||||||
|
weights = [1, 2, 3, 4, 5, 6, 7]
|
||||||
|
output_indices, output_values, output_shape = self.evaluate(
|
||||||
|
gen_count_ops.RaggedCountSparseOutput(
|
||||||
|
splits=splits, values=values, weights=weights, binary_output=False))
|
||||||
|
self.assertAllEqual([[0, 1], [0, 2], [1, 2], [1, 5], [1, 10]],
|
||||||
|
output_indices)
|
||||||
|
self.assertAllEqual([7, 3, 5, 7, 6], output_values)
|
||||||
|
self.assertAllEqual([2, 11], output_shape)
|
||||||
|
|
||||||
|
def testRaggedCountSparseOutputBadWeightsShape(self):
|
||||||
|
splits = [0, 4, 7]
|
||||||
|
values = [1, 1, 2, 1, 2, 10, 5]
|
||||||
|
weights = [1, 2, 3, 4, 5, 6]
|
||||||
|
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||||
|
"Weights and values must have the same shape"):
|
||||||
|
self.evaluate(
|
||||||
|
gen_count_ops.RaggedCountSparseOutput(
|
||||||
|
splits=splits,
|
||||||
|
values=values,
|
||||||
|
weights=weights,
|
||||||
|
binary_output=False))
|
||||||
|
|
||||||
|
def testRaggedCountSparseOutputEmptySplits(self):
|
||||||
|
splits = []
|
||||||
|
values = [1, 1, 2, 1, 2, 10, 5]
|
||||||
|
weights = [1, 2, 3, 4, 5, 6, 7]
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
errors.InvalidArgumentError,
|
||||||
|
"Must provide at least 2 elements for the splits argument"):
|
||||||
|
self.evaluate(
|
||||||
|
gen_count_ops.RaggedCountSparseOutput(
|
||||||
|
splits=splits,
|
||||||
|
values=values,
|
||||||
|
weights=weights,
|
||||||
|
binary_output=False))
|
||||||
|
|
||||||
|
def testRaggedCountSparseOutputBadSplitsStart(self):
|
||||||
|
splits = [1, 7]
|
||||||
|
values = [1, 1, 2, 1, 2, 10, 5]
|
||||||
|
weights = [1, 2, 3, 4, 5, 6, 7]
|
||||||
|
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||||
|
"Splits must start with 0"):
|
||||||
|
self.evaluate(
|
||||||
|
gen_count_ops.RaggedCountSparseOutput(
|
||||||
|
splits=splits,
|
||||||
|
values=values,
|
||||||
|
weights=weights,
|
||||||
|
binary_output=False))
|
||||||
|
|
||||||
|
def testRaggedCountSparseOutputBadSplitsEnd(self):
|
||||||
|
splits = [0, 5]
|
||||||
|
values = [1, 1, 2, 1, 2, 10, 5]
|
||||||
|
weights = [1, 2, 3, 4, 5, 6, 7]
|
||||||
|
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||||
|
"Splits must end with the number of values"):
|
||||||
|
self.evaluate(
|
||||||
|
gen_count_ops.RaggedCountSparseOutput(
|
||||||
|
splits=splits,
|
||||||
|
values=values,
|
||||||
|
weights=weights,
|
||||||
|
binary_output=False))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -18,16 +18,22 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops import gen_data_flow_ops
|
||||||
from tensorflow.python.ops import gen_math_ops
|
from tensorflow.python.ops import gen_math_ops
|
||||||
|
from tensorflow.python.ops import gen_string_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class RawOpsTest(test.TestCase):
|
@test_util.disable_tfrt
|
||||||
|
class RawOpsTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def testSimple(self):
|
def testSimple(self):
|
||||||
x = constant_op.constant(1)
|
x = constant_op.constant(1)
|
||||||
@ -58,6 +64,29 @@ class RawOpsTest(test.TestCase):
|
|||||||
gen_math_ops.Any(input=x, axis=0),
|
gen_math_ops.Any(input=x, axis=0),
|
||||||
gen_math_ops.Any(input=x, axis=0, keep_dims=False))
|
gen_math_ops.Any(input=x, axis=0, keep_dims=False))
|
||||||
|
|
||||||
|
@parameterized.parameters([[0, 8]], [[-1, 6]])
|
||||||
|
def testStringNGramsBadDataSplits(self, splits):
|
||||||
|
data = ["aa", "bb", "cc", "dd", "ee", "ff"]
|
||||||
|
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||||
|
"Invalid split value"):
|
||||||
|
self.evaluate(
|
||||||
|
gen_string_ops.string_n_grams(
|
||||||
|
data=data,
|
||||||
|
data_splits=splits,
|
||||||
|
separator="",
|
||||||
|
ngram_widths=[2],
|
||||||
|
left_pad="",
|
||||||
|
right_pad="",
|
||||||
|
pad_width=0,
|
||||||
|
preserve_short_sequences=False))
|
||||||
|
|
||||||
|
def testGetSessionHandle(self):
|
||||||
|
if context.executing_eagerly():
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
errors.FailedPreconditionError,
|
||||||
|
"GetSessionHandle called on null session state"):
|
||||||
|
gen_data_flow_ops.GetSessionHandle(value=[1])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -29,6 +30,7 @@ from tensorflow.python.framework import test_util
|
|||||||
# Need array_grad to register gradient for Identity.
|
# Need array_grad to register gradient for Identity.
|
||||||
from tensorflow.python.ops import array_grad # pylint: disable=unused-import
|
from tensorflow.python.ops import array_grad # pylint: disable=unused-import
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import gen_sparse_ops
|
||||||
from tensorflow.python.ops import gradient_checker_v2 as gradient_checker
|
from tensorflow.python.ops import gradient_checker_v2 as gradient_checker
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
# Need sparse_grad to register gradient for SparseToDense.
|
# Need sparse_grad to register gradient for SparseToDense.
|
||||||
@ -181,5 +183,57 @@ class SparseOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
self.assertAllEqual(expected, result)
|
self.assertAllEqual(expected, result)
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
|
class RawOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def testSparseFillEmptyRowsGrad(self):
|
||||||
|
reverse_index_map = [2, 1]
|
||||||
|
grad_values = [0, 1, 2, 3]
|
||||||
|
d_values, d_default_value = self.evaluate(
|
||||||
|
gen_sparse_ops.SparseFillEmptyRowsGrad(
|
||||||
|
reverse_index_map=reverse_index_map, grad_values=grad_values))
|
||||||
|
self.assertAllEqual([2, 1], d_values)
|
||||||
|
self.assertEqual(3, d_default_value)
|
||||||
|
|
||||||
|
def testSparseFillEmptyRowsGradNegativeIndexMapValue(self):
|
||||||
|
reverse_index_map = [2, -1]
|
||||||
|
grad_values = [0, 1, 2, 3]
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
errors.InvalidArgumentError,
|
||||||
|
r'Elements in reverse index must be in \[0, 4\)'):
|
||||||
|
self.evaluate(
|
||||||
|
gen_sparse_ops.SparseFillEmptyRowsGrad(
|
||||||
|
reverse_index_map=reverse_index_map, grad_values=grad_values))
|
||||||
|
|
||||||
|
def testSparseFillEmptyRowsGradLargeIndexMapValue(self):
|
||||||
|
reverse_index_map = [2, 10]
|
||||||
|
grad_values = [0, 1, 2, 3]
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
errors.InvalidArgumentError,
|
||||||
|
r'Elements in reverse index must be in \[0, 4\)'):
|
||||||
|
self.evaluate(
|
||||||
|
gen_sparse_ops.SparseFillEmptyRowsGrad(
|
||||||
|
reverse_index_map=reverse_index_map, grad_values=grad_values))
|
||||||
|
|
||||||
|
def testSparseFillEmptyRowsGradMatrix(self):
|
||||||
|
reverse_index_map = [0, 1]
|
||||||
|
grad_values = [[0, 1], [2, 3]]
|
||||||
|
# Note: Eager mode and graph mode throw different errors here. Graph mode
|
||||||
|
# will fail with a ValueError from the shape checking logic, while Eager
|
||||||
|
# will fail with an InvalidArgumentError from the kernel itself.
|
||||||
|
if context.executing_eagerly():
|
||||||
|
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||||
|
r'grad_values must be a vector'):
|
||||||
|
self.evaluate(
|
||||||
|
gen_sparse_ops.SparseFillEmptyRowsGrad(
|
||||||
|
reverse_index_map=reverse_index_map, grad_values=grad_values))
|
||||||
|
else:
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'Shape must be rank 1 but is rank 2'):
|
||||||
|
self.evaluate(
|
||||||
|
gen_sparse_ops.SparseFillEmptyRowsGrad(
|
||||||
|
reverse_index_map=reverse_index_map, grad_values=grad_values))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
googletest.main()
|
googletest.main()
|
||||||
|
@ -1129,9 +1129,16 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
|||||||
// DLPack functions
|
// DLPack functions
|
||||||
m.def("TFE_ToDlpackCapsule", [](py::handle& o) {
|
m.def("TFE_ToDlpackCapsule", [](py::handle& o) {
|
||||||
PyObject* eager_tensor_pyobject_ptr = o.ptr();
|
PyObject* eager_tensor_pyobject_ptr = o.ptr();
|
||||||
TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
|
|
||||||
tensorflow::Safe_TF_StatusPtr status =
|
tensorflow::Safe_TF_StatusPtr status =
|
||||||
tensorflow::make_safe(TF_NewStatus());
|
tensorflow::make_safe(TF_NewStatus());
|
||||||
|
|
||||||
|
if (!EagerTensor_CheckExact(eager_tensor_pyobject_ptr)) {
|
||||||
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"The argument to `to_dlpack` must be a TF tensor, not Python object");
|
||||||
|
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
|
||||||
void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get());
|
void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get());
|
||||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user