From 4b628e8a154c4fbd74ed13d3284fb887a5103e41 Mon Sep 17 00:00:00 2001 From: Dero Gharibian Date: Wed, 7 Aug 2019 11:03:26 -0700 Subject: [PATCH] Updated the majority of string tensor accessors to use tstring type. This is a part of a larger migration effort for tensorflow::tstring. See: https://github.com/tensorflow/community/pull/91 PiperOrigin-RevId: 262172788 --- tensorflow/c/c_api_test.cc | 6 +- tensorflow/c/tf_tensor.cc | 4 +- tensorflow/cc/framework/cc_op_gen.cc | 2 +- tensorflow/cc/framework/ops.cc | 2 +- tensorflow/cc/saved_model/loader.cc | 4 +- tensorflow/compiler/jit/kernels/xla_ops.cc | 4 +- .../kernels/get_calibration_data_op.cc | 4 +- .../kernels/trt_engine_resource_ops.cc | 6 +- .../compiler/xrt/kernels/xrt_compile_ops.cc | 4 +- .../compiler/xrt/kernels/xrt_execute_op.cc | 6 +- .../compiler/xrt/kernels/xrt_state_ops.h | 6 +- tensorflow/compiler/xrt/tests/raw_api_test.cc | 63 ++++++++++--------- .../bigtable/kernels/bigtable_kernels.cc | 10 +-- .../kernels/bigtable_lookup_dataset_op.cc | 8 +-- .../kernels/bigtable_prefix_key_dataset_op.cc | 2 +- .../kernels/bigtable_range_key_dataset_op.cc | 2 +- .../bigtable_sample_key_pairs_dataset_op.cc | 4 +- .../bigtable_sample_keys_dataset_op.cc | 2 +- .../kernels/bigtable_scan_dataset_op.cc | 2 +- .../boosted_trees/kernels/model_ops.cc | 6 +- .../boosted_trees/kernels/quantile_ops.cc | 4 +- .../kernels/split_handler_ops.cc | 12 ++-- .../boosted_trees/kernels/training_ops.cc | 4 +- .../cloud/kernels/bigquery_reader_ops.cc | 2 +- tensorflow/contrib/ffmpeg/decode_audio_op.cc | 7 ++- tensorflow/contrib/ffmpeg/decode_video_op.cc | 3 +- tensorflow/contrib/ffmpeg/encode_audio_op.cc | 4 +- .../hadoop/kernels/hadoop_dataset_ops.cc | 6 +- .../dataset/ignite_binary_object_parser.cc | 4 +- .../kernels/input_pipeline_kernels.cc | 4 +- .../kafka/kernels/kafka_dataset_ops.cc | 4 +- .../kernels/sparse_feature_cross_kernel.cc | 12 ++-- .../libsvm/kernels/decode_libsvm_op.cc | 2 +- .../contrib/session_bundle/session_bundle.cc | 2 +- .../tensor_forest/kernels/model_ops.cc | 8 +-- .../kernels/reinterpret_string_to_float_op.cc | 2 +- .../tensor_forest/kernels/stats_ops.cc | 11 ++-- .../kernels/v4/candidate_graph_runner.cc | 2 +- .../common_runtime/direct_session_test.cc | 12 ++-- .../core/common_runtime/function_testlib.cc | 2 +- .../common_runtime/rendezvous_util_test.cc | 4 +- .../core/debug/debug_grpc_io_utils_test.cc | 6 +- tensorflow/core/debug/debug_io_utils_test.cc | 8 +-- .../core/debug/grpc_session_debug_test.cc | 6 +- .../rpc/grpc_rpc_factory.cc | 4 +- .../rpc/grpc_session_test.cc | 4 +- .../rpc/rpc_rendezvous_mgr_test.cc | 4 +- .../example/example_parser_configuration.cc | 5 +- .../core/framework/op_compatibility_test.cc | 4 +- tensorflow/core/framework/reader_base.cc | 2 +- tensorflow/core/framework/rendezvous_test.cc | 4 +- tensorflow/core/framework/resource_mgr.h | 4 +- .../core/framework/resource_op_kernel.h | 2 +- tensorflow/core/framework/tensor_test.cc | 20 +++--- tensorflow/core/framework/tensor_util.cc | 6 +- tensorflow/core/framework/tensor_util_test.cc | 35 ++++++----- .../core/framework/variant_op_copy_test.cc | 8 +-- tensorflow/core/graph/quantize_training.cc | 12 ++-- tensorflow/core/grappler/costs/utils.cc | 2 +- tensorflow/core/grappler/graph_view_test.cc | 2 +- tensorflow/core/kernels/as_string_op.cc | 2 +- tensorflow/core/kernels/barrier_ops.cc | 4 +- tensorflow/core/kernels/base64_ops.cc | 8 +-- .../kernels/boosted_trees/prediction_ops.cc | 2 +- .../kernels/boosted_trees/resource_ops.cc | 6 +- .../core/kernels/boosted_trees/stats_ops.cc | 4 +- .../kernels/conditional_accumulator_base_op.h | 2 +- .../kernels/conditional_accumulator_op.cc | 2 +- tensorflow/core/kernels/data/dataset_ops.cc | 2 +- .../data/experimental/csv_dataset_op.cc | 8 +-- .../data/experimental/lmdb_dataset_op.cc | 6 +- .../experimental/matching_files_dataset_op.cc | 4 +- .../data/experimental/prefetching_kernels.cc | 2 +- .../data/experimental/stats_aggregator_ops.cc | 2 +- .../data/experimental/to_tf_record_op.cc | 2 +- .../data/experimental/unique_dataset_op.cc | 2 +- .../data/fixed_length_record_dataset_op.cc | 8 +-- tensorflow/core/kernels/data/iterator_ops.cc | 4 +- .../kernels/data/multi_device_iterator_ops.cc | 4 +- .../core/kernels/data/text_line_dataset_op.cc | 4 +- .../core/kernels/data/tf_record_dataset_op.cc | 6 +- tensorflow/core/kernels/decode_bmp_op.cc | 2 +- .../core/kernels/decode_compressed_op.cc | 4 +- tensorflow/core/kernels/decode_csv_op.cc | 8 +-- tensorflow/core/kernels/decode_image_op.cc | 2 +- .../core/kernels/decode_padded_raw_op.cc | 2 +- tensorflow/core/kernels/decode_proto_op.cc | 8 +-- tensorflow/core/kernels/decode_raw_op.cc | 2 +- tensorflow/core/kernels/decode_wav_op.cc | 2 +- .../kernels/deserialize_sparse_string_op.cc | 4 +- tensorflow/core/kernels/encode_proto_op.cc | 4 +- .../core/kernels/example_parsing_ops.cc | 20 +++--- .../core/kernels/example_parsing_ops_test.cc | 6 +- .../core/kernels/extract_jpeg_shape_op.cc | 2 +- tensorflow/core/kernels/fact_op.cc | 2 +- tensorflow/core/kernels/fingerprint_op.cc | 4 +- .../core/kernels/fingerprint_op_test.cc | 14 ++--- tensorflow/core/kernels/function_ops.cc | 2 +- tensorflow/core/kernels/functional_ops.cc | 2 +- .../example_proto_fast_parsing_fuzz.cc | 2 +- .../core/kernels/fuzzing/fuzz_session.h | 2 +- .../kernels/fuzzing/parse_tensor_op_fuzz.cc | 2 +- .../core/kernels/fuzzing/string_split_fuzz.cc | 4 +- .../kernels/fuzzing/string_split_v2_fuzz.cc | 6 +- .../kernels/generate_vocab_remapping_op.cc | 6 +- tensorflow/core/kernels/inplace_ops.cc | 4 +- .../core/kernels/load_and_remap_matrix_op.cc | 4 +- tensorflow/core/kernels/logging_ops.cc | 2 +- .../core/kernels/lookup_table_init_op.cc | 2 +- tensorflow/core/kernels/lookup_table_op.h | 2 +- tensorflow/core/kernels/lookup_util.cc | 4 +- tensorflow/core/kernels/matching_files_op.cc | 4 +- tensorflow/core/kernels/parse_tensor_op.cc | 2 +- tensorflow/core/kernels/queue_ops.cc | 4 +- tensorflow/core/kernels/reader_ops.cc | 6 +- tensorflow/core/kernels/reduce_join_op.cc | 4 +- .../core/kernels/regex_full_match_op.cc | 6 +- tensorflow/core/kernels/regex_replace_op.cc | 8 +-- .../core/kernels/regex_replace_op_test.cc | 2 +- .../remote_fused_graph_execute_utils.cc | 2 +- tensorflow/core/kernels/restore_op_test.cc | 36 +++++------ tensorflow/core/kernels/restore_v2_op_test.cc | 24 +++---- tensorflow/core/kernels/save_op.cc | 8 +-- .../core/kernels/save_restore_tensor.cc | 20 +++--- .../core/kernels/save_restore_v2_ops.cc | 10 +-- tensorflow/core/kernels/sdca_ops.cc | 2 +- tensorflow/core/kernels/session_ops.cc | 6 +- tensorflow/core/kernels/sparse_cross_op.cc | 12 ++-- tensorflow/core/kernels/stack.cc | 6 +- tensorflow/core/kernels/string_format_op.cc | 2 +- tensorflow/core/kernels/string_join_op.cc | 4 +- tensorflow/core/kernels/string_length_op.cc | 2 +- tensorflow/core/kernels/string_lower_op.cc | 4 +- tensorflow/core/kernels/string_ngrams_op.cc | 6 +- tensorflow/core/kernels/string_split_op.cc | 10 +-- .../core/kernels/string_split_op_test.cc | 6 +- tensorflow/core/kernels/string_strip_op.cc | 4 +- .../core/kernels/string_to_hash_bucket_op.cc | 2 +- .../core/kernels/string_to_hash_bucket_op.h | 4 +- .../core/kernels/string_to_number_op.cc | 2 +- tensorflow/core/kernels/string_upper_op.cc | 4 +- tensorflow/core/kernels/substr_op.cc | 12 ++-- tensorflow/core/kernels/substr_op_test.cc | 4 +- tensorflow/core/kernels/summary_audio_op.cc | 2 +- .../core/kernels/summary_audio_op_test.cc | 4 +- tensorflow/core/kernels/summary_image_op.cc | 2 +- .../core/kernels/summary_image_op_test.cc | 6 +- tensorflow/core/kernels/summary_kernels.cc | 30 ++++----- tensorflow/core/kernels/summary_op.cc | 6 +- tensorflow/core/kernels/summary_op_test.cc | 16 ++--- .../core/kernels/summary_tensor_op_test.cc | 4 +- tensorflow/core/kernels/tensor_array.cc | 4 +- tensorflow/core/kernels/tensor_array.h | 18 +++--- tensorflow/core/kernels/tensor_array_ops.cc | 6 +- .../kernels/tensor_forest/resource_ops.cc | 6 +- tensorflow/core/kernels/unicode_ops.cc | 8 +-- .../core/kernels/unsorted_segment_join_op.cc | 4 +- .../core/kernels/whole_file_read_ops.cc | 4 +- tensorflow/core/kernels/word2vec_kernels.cc | 4 +- tensorflow/core/ops/array_ops.cc | 2 +- tensorflow/core/ops/io_ops.cc | 4 +- tensorflow/core/summary/summary_db_writer.cc | 16 ++--- .../core/summary/summary_file_writer_test.cc | 2 +- tensorflow/core/util/batch_util.cc | 4 +- .../core/util/example_proto_fast_parsing.cc | 19 +++--- .../core/util/example_proto_helper_test.cc | 14 ++--- .../core/util/sparse/sparse_tensor_test.cc | 12 ++-- tensorflow/python/framework/test_ops.cc | 8 +-- .../python/kernel_tests/ackermann_op.cc | 2 +- tensorflow/tools/benchmark/benchmark_model.cc | 2 +- .../tools/graph_transforms/sparsify_gather.cc | 4 +- .../graph_transforms/sparsify_gather_test.cc | 6 +- 172 files changed, 532 insertions(+), 525 deletions(-) diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index eb3323e7a06..ddf1f4612f1 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -234,7 +234,7 @@ void TestEncodeDecode(int line, const std::vector& data) { // Create C++ Tensor Tensor src(tensorflow::DT_STRING, TensorShape(dims)); for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) { - src.flat()(i) = data[i]; + src.flat()(i) = data[i]; } TF_Tensor* dst = TF_TensorFromTensor(src, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -244,7 +244,7 @@ void TestEncodeDecode(int line, const std::vector& data) { ASSERT_EQ(Status::OK(), TF_TensorToTensor(dst, &output)) << line; ASSERT_EQ(src.NumElements(), output.NumElements()) << line; for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) { - ASSERT_EQ(data[i], output.flat()(i)) << line; + ASSERT_EQ(data[i], output.flat()(i)) << line; } TF_DeleteTensor(dst); @@ -1386,7 +1386,7 @@ TEST(CAPI, SavedModel) { tensorflow::Example example; auto* feature_map = example.mutable_features()->mutable_feature(); (*feature_map)["x"].mutable_float_list()->add_value(i); - input.flat()(i) = example.SerializeAsString(); + input.flat()(i) = example.SerializeAsString(); } const tensorflow::string input_op_name( diff --git a/tensorflow/c/tf_tensor.cc b/tensorflow/c/tf_tensor.cc index f8d3bc220f9..44efcba99c7 100644 --- a/tensorflow/c/tf_tensor.cc +++ b/tensorflow/c/tf_tensor.cc @@ -354,7 +354,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, // Compute bytes needed for encoding. size_t size = 0; - const auto& srcarray = src.flat(); + const auto& srcarray = src.flat(); for (int i = 0; i < srcarray.size(); ++i) { const string& s = srcarray(i); // uint64 starting_offset, TF_StringEncode-d string. @@ -440,7 +440,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { const char* limit = input + src_size; *dst = Tensor(static_cast(src->dtype), src->shape); - auto dstarray = dst->flat(); + auto dstarray = dst->flat(); for (tensorflow::int64 i = 0; i < num_elements; ++i) { tensorflow::uint64 offset = reinterpret_cast(input)[i]; diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index a0353bf17a6..86f503c9e10 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -193,7 +193,7 @@ string PrintTensor(const TensorProto& tensor_proto) { string ret; for (int64 i = 0; i < num_elts; ++i) { if (i > 0) strings::StrAppend(&ret, " "); - strings::StrAppend(&ret, absl::CEscape(t.flat()(i))); + strings::StrAppend(&ret, absl::CEscape(t.flat()(i))); } return ret; } diff --git a/tensorflow/cc/framework/ops.cc b/tensorflow/cc/framework/ops.cc index 920a8e79556..8516dfd7a29 100644 --- a/tensorflow/cc/framework/ops.cc +++ b/tensorflow/cc/framework/ops.cc @@ -97,7 +97,7 @@ Input::Initializer::Initializer( Tensor elem = e.tensor; if (first.tensor.dtype() == DT_STRING) { for (int i = 0; i < elem.NumElements(); ++i) { - t.flat()(offset + i) = elem.flat()(i); + t.flat()(offset + i) = elem.flat()(i); } offset += elem.NumElements(); } else { diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index dfc7ccd9542..a3b80fbdba5 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -75,7 +75,7 @@ Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, Tensor CreateStringTensor(const string& value) { Tensor tensor(DT_STRING, TensorShape({})); - tensor.scalar()() = value; + tensor.scalar()() = value; return tensor; } @@ -219,7 +219,7 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, // Add variables to the graph. Tensor variables_path_tensor(DT_STRING, TensorShape({})); - variables_path_tensor.scalar()() = variables_path; + variables_path_tensor.scalar()() = variables_path; std::vector> inputs = { {string(variable_filename_const_op_name), variables_path_tensor}}; diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index b23980830ba..87d6548a1a7 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -508,7 +508,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { client, executable, kernel, std::move(variables), constants_.size())); Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({})); - compilation_key.flat()(0) = key; + compilation_key.flat()(0) = key; Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({})); compilation_successful.flat()(0) = true; @@ -523,7 +523,7 @@ XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) void XlaRunOp::Compute(OpKernelContext* ctx) { VLOG(3) << "XlaRunOp " << def().name(); Tensor key_tensor = ctx->input(ctx->num_inputs() - 1); - const XlaExecutableClosureStore::KeyT& key = key_tensor.flat()(0); + const XlaExecutableClosureStore::KeyT& key = key_tensor.flat()(0); XlaExecutableClosure closure = XlaExecutableClosureStore::Global()->Consume(key); diff --git a/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc index 83a16892816..374f75c0ab9 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc @@ -40,7 +40,7 @@ class GetCalibrationDataOp : public OpKernel { // serialized string to that tensor, and later sess.run() will copy it back // to host. We need to optimize this. - const string& resource_name = context->input(0).scalar()(); + const string& resource_name = context->input(0).scalar()(); // Get the resource. TRTEngineCacheResource* resource = nullptr; OP_REQUIRES_OK(context, context->resource_manager()->Lookup( @@ -59,7 +59,7 @@ class GetCalibrationDataOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &output)); - output->scalar()() = serialized_resource; + output->scalar()() = serialized_resource; } }; diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc index e28dcc1cbba..51f7e3aabc5 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc @@ -109,7 +109,7 @@ class InitializeTRTResource : public OpKernel { resource->cache_.size(), " entries.")); // Get the file name. - const string& filename = ctx->input(1).scalar()(); + const string& filename = ctx->input(1).scalar()(); OP_REQUIRES(ctx, !filename.empty(), errors::InvalidArgument("filename cannot be empty.")); @@ -171,8 +171,8 @@ class SerializeTRTResource : public OpKernel { } void Compute(OpKernelContext* ctx) override { - const string& resource_name = ctx->input(0).scalar()(); - const string& filename = ctx->input(1).scalar()(); + const string& resource_name = ctx->input(0).scalar()(); + const string& filename = ctx->input(1).scalar()(); OP_REQUIRES(ctx, !filename.empty(), errors::InvalidArgument("filename cannot be empty.")); diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc index b791519c097..89daa98ee18 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc @@ -151,7 +151,7 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) { xrt::XLAComputation computation_proto; OP_REQUIRES( ctx, - computation_proto.ParseFromString(computation_input.scalar()()), + computation_proto.ParseFromString(computation_input.scalar()()), errors::InvalidArgument( "Unable to parse computation input to XLAComputation")); @@ -191,7 +191,7 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) { .ComputeProgramShape() .ToProto(); Tensor program_shape_output(DT_STRING, TensorShape({1})); - program_shape_output.vec()(0) = program_shape.SerializeAsString(); + program_shape_output.vec()(0) = program_shape.SerializeAsString(); ctx->set_output(1, program_shape_output); } diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index 231387e314f..1c4e1f7e2c7 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -260,7 +260,7 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape())); xrt::XRTExecutionConfig config_proto; TF_RET_CHECK( - config_proto.ParseFromString(execution_config.scalar()())); + config_proto.ParseFromString(execution_config.scalar()())); int core_index_in_replica = config_proto.core_index_in_replica(); TF_RET_CHECK(core_index_in_replica == 0); @@ -343,12 +343,12 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) { const Tensor& execution_plan = context->input(0); TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_plan.shape())); xrt::XRTChainedExecutePlan plan; - TF_RET_CHECK(plan.ParseFromString(execution_plan.scalar()())); + TF_RET_CHECK(plan.ParseFromString(execution_plan.scalar()())); const Tensor& execution_config = context->input(1); TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape())); xrt::XRTChainedExecuteConfig config; - TF_RET_CHECK(config.ParseFromString(execution_config.scalar()())); + TF_RET_CHECK(config.ParseFromString(execution_config.scalar()())); XRTCompilationCache* cache; TF_RETURN_IF_ERROR(rm->Lookup( diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index 2ffde52af06..8afd2051c00 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -177,7 +177,7 @@ class XRTAllocateOp : public OpKernel { xrt::XLAAllocation allocation_proto; OP_REQUIRES( ctx, - allocation_proto.ParseFromString(allocation_info.scalar()()), + allocation_proto.ParseFromString(allocation_info.scalar()()), errors::InvalidArgument( "Unable to parse allocation input to XLAAllocation")); @@ -419,7 +419,7 @@ class XRTMakeTupleOp : public OpKernel { errors::Internal("tuple description input should be a string scalar")); xrt::XLATupleNode tuple_proto; OP_REQUIRES( - ctx, tuple_proto.ParseFromString(tuple_info.scalar()()), + ctx, tuple_proto.ParseFromString(tuple_info.scalar()()), errors::InvalidArgument("Unable to parse tuple input to XLATupleNode")); OpInputList arg_list; @@ -627,7 +627,7 @@ class XRTWriteLiteralOp : public OpKernel { errors::Internal("literal input should be a string scalar")); xla::LiteralProto literal_proto; OP_REQUIRES(ctx, - literal_proto.ParseFromString(literal_info.scalar()()), + literal_proto.ParseFromString(literal_info.scalar()()), errors::InvalidArgument( "Unable to parse allocation input to LiteralProto")); xla::Literal literal; diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index f0729251eeb..427a631f82d 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -127,7 +127,7 @@ xla::LiteralProto FloatMatrix( xla::Literal ReadOutputLiteral(const std::vector& outputs, size_t idx) { xla::LiteralProto response; - CHECK(response.ParseFromString(outputs[idx].scalar()())); + CHECK(response.ParseFromString(outputs[idx].scalar()())); return xla::Literal::CreateFromProto(response).ValueOrDie(); } @@ -316,7 +316,7 @@ TEST(RawApiTest, AllocFromTensor) { EXPECT_EQ(outputs.size(), 1); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response)); } @@ -351,7 +351,7 @@ TEST(RawApiTest, AllocUninitialized) { EXPECT_EQ(outputs.size(), 1); xla::LiteralProto read_back_literal; EXPECT_TRUE( - read_back_literal.ParseFromString(outputs[0].scalar()())); + read_back_literal.ParseFromString(outputs[0].scalar()())); Tensor read_back_tensor; TF_ASSERT_OK(LiteralToHostTensor( xla::Literal::CreateFromProto(read_back_literal).ValueOrDie(), DT_FLOAT, @@ -381,7 +381,7 @@ TEST(RawApiTest, AllocUninitialized) { EXPECT_EQ(outputs.size(), 1); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralProtos(response, new_literal)); } } @@ -413,7 +413,7 @@ TEST(RawApiTest, AllocFromTensorTuple) { EXPECT_EQ(outputs.size(), 1); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response)); } @@ -439,7 +439,7 @@ TEST(RawApiTest, AllocFromTensorTupleSingle) { EXPECT_EQ(outputs.size(), 1); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response)); } @@ -465,7 +465,7 @@ TEST(RawApiTest, AllocFromTensorRelayout) { EXPECT_EQ(outputs.size(), 1); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); // We have sent literal's data (in array layout) with a attribute layout // {0,1}, so the expected literal read from device needs to be changed // accordingly. @@ -493,7 +493,7 @@ TEST(RawApiTest, AllocAndRewrite) { int64 allocation_handle = outputs[1].scalar()(); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response)); xla::LiteralProto new_literal = @@ -512,7 +512,7 @@ TEST(RawApiTest, AllocAndRewrite) { EXPECT_EQ(outputs.size(), 1); xla::LiteralProto new_response; - EXPECT_TRUE(new_response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(new_response.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response)); Tensor release_tensor(DT_INT64, TensorShape({1})); @@ -652,7 +652,7 @@ TEST(RawApiTest, ReadAndWriteState) { session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs)); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response)); } @@ -673,7 +673,7 @@ TEST(RawApiTest, ReadAndWriteStateAutoFree) { TF_EXPECT_OK(session.Run({read_back}, &outputs)); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response)); } @@ -707,13 +707,13 @@ TEST(RawApiTest, SubBuffer) { auto base_elements = base_literal.DecomposeTuple(); auto nested_0_elements = base_elements[0].Clone().DecomposeTuple(); xla::LiteralProto response_0; - EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[0], response_0)); xla::LiteralProto response_1; - EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar()())); + EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar()())); EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[1], response_1)); xla::LiteralProto response_00; - EXPECT_TRUE(response_00.ParseFromString(outputs[2].scalar()())); + EXPECT_TRUE(response_00.ParseFromString(outputs[2].scalar()())); EXPECT_TRUE(CompareLiteralToLiteralProto(nested_0_elements[0], response_00)); } @@ -779,9 +779,9 @@ TEST(RawApiTest, MakeTuple) { std::vector outputs; TF_EXPECT_OK(session.Run({res_0, res_1}, &outputs)); xla::LiteralProto response_0; - EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar()())); xla::LiteralProto response_1; - EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar()())); + EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar()())); auto expected_0 = MakeTuple0(); EXPECT_TRUE(CompareLiteralProtos(response_0, expected_0)); @@ -853,7 +853,7 @@ TEST(RawApiTest, ExecuteChainedOpByOp) { TF_EXPECT_OK(session.Run({read_back}, &outputs)); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR1({-150.0f, -36.0f}); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); @@ -973,7 +973,7 @@ TEST(RawApiTest, ExecuteChained) { EXPECT_EQ(outputs.size(), 1); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR1({-150.0f, -36.0f}); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); @@ -1022,13 +1022,13 @@ TEST(RawApiTest, CompileAndExecute) { TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR1({27.0f, 21.0f}); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); xla::ProgramShapeProto program_shape; - EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); + EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); EXPECT_EQ(program_shape.parameters_size(), 2); } @@ -1077,13 +1077,13 @@ TEST(RawApiTest, CompileAndExecuteWithArgumentVector) { TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR1({27.0f, 21.0f}); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); xla::ProgramShapeProto program_shape; - EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); + EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); EXPECT_EQ(program_shape.parameters_size(), 2); } @@ -1128,7 +1128,8 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) { {release}, &outputs)); xla::ProgramShapeProto program_shape_proto; - EXPECT_TRUE(program_shape_proto.ParseFromString(outputs[0].vec()(0))); + EXPECT_TRUE( + program_shape_proto.ParseFromString(outputs[0].vec()(0))); xla::ProgramShape program_shape(program_shape_proto); EXPECT_EQ(program_shape.parameters_size(), 1); @@ -1196,7 +1197,7 @@ TEST(RawApiTest, DotGeneralWithLayoutTest) { TF_EXPECT_OK(session.Run({read_back}, &outputs)); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR2WithLayout({{18.0f}, {44.0f}}, layout); @@ -1231,7 +1232,7 @@ TEST(RawApiTest, CompileAndExecuteZeroArg) { TF_EXPECT_OK(session.Run({read_back}, &outputs)); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR0(3.0f); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); @@ -1281,7 +1282,7 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) { TF_EXPECT_OK(session.Run({read_back}, &outputs)); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto sum = xla::LiteralUtil::CreateR1({9.0f, 7.0f}); auto expected = xla::LiteralUtil::MakeTuple({&sum}); @@ -1343,7 +1344,7 @@ TEST(RawApiTest, CompileAndExecuteReturnExplodedTuple) { EXPECT_EQ(voutputs.size(), 1); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(voutputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(voutputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR0(kResults[i]); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); @@ -1514,13 +1515,13 @@ TEST(RawApiTest, CompileAndExecuteWithS64Argument) { TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR0(15123899); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); xla::ProgramShapeProto program_shape; - EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); + EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); EXPECT_EQ(program_shape.parameters_size(), 2); EXPECT_TRUE(xla::ShapeUtil::HasPrimitiveType( xla::Shape(program_shape.result()), xla::S64)); @@ -1580,7 +1581,7 @@ TEST(RawApiTest, TestDeviceMemoryCompaction) { // we have on record. for (size_t i = 1, j = 0; i < handles.size(); i += 2, ++j) { xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[j].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[j].scalar()())); EXPECT_TRUE(CompareLiteralProtos(allocs[i].value(), response)); } } @@ -1668,7 +1669,7 @@ TEST(RawApiTest, TestDeviceMemorySwap) { EXPECT_EQ(outputs.size(), 1); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto literal = xla::Literal::CreateFromProto(response).ValueOrDie(); EXPECT_EQ(literal, zero_literal); } diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc index 51b27ea4212..1e6de7ee17e 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc @@ -214,8 +214,8 @@ class ToBigtableOp : public AsyncOpKernel { std::vector columns; columns.reserve(column_families_tensor->NumElements()); for (uint64 i = 0; i < column_families_tensor->NumElements(); ++i) { - column_families.push_back(column_families_tensor->flat()(i)); - columns.push_back(columns_tensor->flat()(i)); + column_families.push_back(column_families_tensor->flat()(i)); + columns.push_back(columns_tensor->flat()(i)); } DatasetBase* dataset; @@ -317,7 +317,7 @@ class ToBigtableOp : public AsyncOpKernel { "Iterator produced a set of Tensors shorter than expected"); } ::google::cloud::bigtable::SingleRowMutation mutation( - std::move(tensors[0].scalar()())); + std::move(tensors[0].scalar()())); std::chrono::milliseconds timestamp(timestamp_int); for (size_t i = 1; i < tensors.size(); ++i) { if (!TensorShapeUtils::IsScalar(tensors[i].shape())) { @@ -326,11 +326,11 @@ class ToBigtableOp : public AsyncOpKernel { if (timestamp_int == -1) { mutation.emplace_back(::google::cloud::bigtable::SetCell( column_families[i - 1], columns[i - 1], - std::move(tensors[i].scalar()()))); + std::move(tensors[i].scalar()()))); } else { mutation.emplace_back(::google::cloud::bigtable::SetCell( column_families[i - 1], columns[i - 1], timestamp, - std::move(tensors[i].scalar()()))); + std::move(tensors[i].scalar()()))); } } bulk_mutation->emplace_back(std::move(mutation)); diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc index f1973e370a8..b341b0cae26 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc @@ -156,13 +156,13 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { ::google::cloud::StatusOr< std::pair> row = dataset()->table_->table().ReadRow( - input_tensors[0].scalar()(), dataset()->filter_); + input_tensors[0].scalar()(), dataset()->filter_); if (!row.ok()) { return GcpStatusToTfStatus(row.status()); } if (!row->first) { return errors::DataLoss("Row key '", - input_tensors[0].scalar()(), + input_tensors[0].scalar()(), "' not found."); } TF_RETURN_IF_ERROR(ParseRow(ctx, row->second, out_tensors)); @@ -180,7 +180,7 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { std::vector* out_tensors) { out_tensors->reserve(dataset()->columns_.size() + 1); Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {}); - row_key_tensor.scalar()() = string(row.row_key()); + row_key_tensor.scalar()() = tstring(row.row_key()); out_tensors->emplace_back(std::move(row_key_tensor)); if (row.cells().size() > 2 * dataset()->columns_.size()) { @@ -198,7 +198,7 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { if (cell_itr->family_name() == dataset()->column_families_[i] && string(cell_itr->column_qualifier()) == dataset()->columns_[i]) { - col_tensor.scalar()() = string(cell_itr->value()); + col_tensor.scalar()() = tstring(cell_itr->value()); found_column = true; } } diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc index 352e7af7de9..3908d40908d 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc @@ -99,7 +99,7 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel { const ::google::cloud::bigtable::Row& row, std::vector* out_tensors) override { Tensor output_tensor(ctx->allocator({}), DT_STRING, {}); - output_tensor.scalar()() = string(row.row_key()); + output_tensor.scalar()() = tstring(row.row_key()); out_tensors->emplace_back(std::move(output_tensor)); return Status::OK(); } diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc index 591bc786bd8..e3e6acba351 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc @@ -105,7 +105,7 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel { const ::google::cloud::bigtable::Row& row, std::vector* out_tensors) override { Tensor output_tensor(ctx->allocator({}), DT_STRING, {}); - output_tensor.scalar()() = string(row.row_key()); + output_tensor.scalar()() = string(row.row_key()); out_tensors->emplace_back(std::move(output_tensor)); return Status::OK(); } diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc index d3780f56ed8..0ca39f18670 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc @@ -177,11 +177,11 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel { *end_of_sequence = false; out_tensors->emplace_back(ctx->allocator({}), DT_STRING, TensorShape({})); - out_tensors->back().scalar()() = keys_[index_]; + out_tensors->back().scalar()() = keys_[index_]; out_tensors->emplace_back(ctx->allocator({}), DT_STRING, TensorShape({})); - out_tensors->back().scalar()() = keys_[index_ + 1]; + out_tensors->back().scalar()() = keys_[index_ + 1]; ++index_; return Status::OK(); diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc index 967920ef3dc..513514f63c1 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc @@ -99,7 +99,7 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel { if (index_ < row_keys_.size()) { out_tensors->emplace_back(ctx->allocator({}), DT_STRING, TensorShape({})); - out_tensors->back().scalar()() = + out_tensors->back().scalar()() = string(row_keys_[index_].row_key); *end_of_sequence = false; index_++; diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc index c94e66057c9..d5071537b9b 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc @@ -177,7 +177,7 @@ class BigtableScanDatasetOp : public DatasetOpKernel { std::vector* out_tensors) override { out_tensors->reserve(dataset()->columns_.size() + 1); Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {}); - row_key_tensor.scalar()() = string(row.row_key()); + row_key_tensor.scalar()() = string(row.row_key()); out_tensors->emplace_back(std::move(row_key_tensor)); if (row.cells().size() > 2 * dataset()->columns_.size()) { diff --git a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc index 9655e49d91b..5f9976a491c 100644 --- a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc @@ -46,7 +46,7 @@ class CreateTreeEnsembleVariableOp : public OpKernel { OP_REQUIRES_OK(context, context->input("tree_ensemble_config", &tree_ensemble_config_t)); auto* result = new DecisionTreeEnsembleResource(); - if (!result->InitFromSerialized(tree_ensemble_config_t->scalar()(), + if (!result->InitFromSerialized(tree_ensemble_config_t->scalar()(), stamp_token)) { result->Unref(); OP_REQUIRES( @@ -99,7 +99,7 @@ class TreeEnsembleSerializeOp : public OpKernel { Tensor* output_config_t = nullptr; OP_REQUIRES_OK( context, context->allocate_output(1, TensorShape(), &output_config_t)); - output_config_t->scalar()() = + output_config_t->scalar()() = ensemble_resource->SerializeAsString(); } }; @@ -130,7 +130,7 @@ class TreeEnsembleDeserializeOp : public OpKernel { OP_REQUIRES( context, ensemble_resource->InitFromSerialized( - tree_ensemble_config_t->scalar()(), stamp_token), + tree_ensemble_config_t->scalar()(), stamp_token), errors::InvalidArgument("Unable to parse tree ensemble config.")); } }; diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc index 431dc68836b..bea5c2a839a 100644 --- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc @@ -324,7 +324,7 @@ class QuantileAccumulatorAddSummariesOp : public OpKernel { context, ParseProtoUnlimited( summary_proto, - summary_list[resource_handle_idx].scalar()()), + summary_list[resource_handle_idx].scalar()()), errors::InvalidArgument("Unable to parse quantile summary.")); std::vector entries; entries.reserve(summary_proto->entries_size()); @@ -543,7 +543,7 @@ class QuantileAccumulatorDeserializeOp : public OpKernel { ::boosted_trees::QuantileStreamState state_proto; OP_REQUIRES( context, - ParseProtoUnlimited(&state_proto, stream_state_t->scalar()()), + ParseProtoUnlimited(&state_proto, stream_state_t->scalar()()), errors::InvalidArgument("Unabnle to parse quantile stream state.")); std::vector summaries; summaries.reserve(state_proto.summaries_size()); diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 65276242aba..4e96957130e 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -213,8 +213,8 @@ class BuildDenseInequalitySplitsOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output("split_infos", TensorShape({size_output}), &output_splits_t)); - tensorflow::TTypes::Vec output_splits = - output_splits_t->vec(); + tensorflow::TTypes::Vec output_splits = + output_splits_t->vec(); if (num_elements == 0) { return; @@ -529,8 +529,8 @@ class BuildSparseInequalitySplitsOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output( "split_infos", TensorShape({num_elements}), &output_splits_t)); - tensorflow::TTypes::Vec output_splits = - output_splits_t->vec(); + tensorflow::TTypes::Vec output_splits = + output_splits_t->vec(); SplitBuilderState state(context); // For each tree node that needs to be split. for (int root_idx = 0; root_idx < num_elements; ++root_idx) { @@ -780,8 +780,8 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output("split_infos", TensorShape({size_output}), &output_splits_t)); - tensorflow::TTypes::Vec output_splits = - output_splits_t->vec(); + tensorflow::TTypes::Vec output_splits = + output_splits_t->vec(); if (num_elements == 0) { return; } diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index 8cb5cfbd3dd..bf5f5d34457 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -468,7 +468,7 @@ class GrowTreeEnsembleOp : public OpKernel { for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) { const auto& partition_ids = partition_ids_list[handler_id].vec(); const auto& gains = gains_list[handler_id].vec(); - const auto& splits = splits_list[handler_id].vec(); + const auto& splits = splits_list[handler_id].vec(); OP_REQUIRES(context, partition_ids.size() == gains.size(), errors::InvalidArgument( "Inconsistent partition Ids and gains tensors: ", @@ -502,7 +502,7 @@ class GrowTreeEnsembleOp : public OpKernel { // Find best split per partition going through every feature candidate. for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) { const auto& gains = gains_list[handler_id].vec(); - const auto& splits = splits_list[handler_id].vec(); + const auto& splits = splits_list[handler_id].vec(); OP_REQUIRES(context, gains.size() == 1, errors::InvalidArgument( "Gains size must be one for oblivious weak learner: ", diff --git a/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc b/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc index b0f9237ea27..7a19a1c9231 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc +++ b/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc @@ -153,7 +153,7 @@ class GenerateBigQueryReaderPartitionsOp : public OpKernel { context->allocate_output(0, TensorShape({num_partitions_}), &output_tensor)); - auto output = output_tensor->template flat(); + auto output = output_tensor->template flat(); for (int64 i = 0; i < num_partitions_; ++i) { BigQueryTablePartition partition; partition.set_start_index(i * partition_size); diff --git a/tensorflow/contrib/ffmpeg/decode_audio_op.cc b/tensorflow/contrib/ffmpeg/decode_audio_op.cc index ca65ad45326..32e62a6725f 100644 --- a/tensorflow/contrib/ffmpeg/decode_audio_op.cc +++ b/tensorflow/contrib/ffmpeg/decode_audio_op.cc @@ -135,9 +135,10 @@ class DecodeAudioOpV2 : public OpKernel { "channel_count must be a rank-0 tensor but got shape ", channel_count_tensor.shape().DebugString())); - const tensorflow::StringPiece contents = contents_tensor.scalar()(); + const tensorflow::StringPiece contents = + contents_tensor.scalar()(); const string file_format = - absl::AsciiStrToLower(file_format_tensor.scalar()()); + absl::AsciiStrToLower(file_format_tensor.scalar()()); const int32 samples_per_second = samples_per_second_tensor.scalar()(); const int32 channel_count = channel_count_tensor.scalar()(); @@ -243,7 +244,7 @@ class DecodeAudioOp : public OpKernel { errors::InvalidArgument("contents must be scalar but got shape ", contents.shape().DebugString())); - const tensorflow::StringPiece file_contents = contents.scalar()(); + const tensorflow::StringPiece file_contents = contents.scalar()(); Decode(context, file_contents, file_format_, samples_per_second_, channel_count_, ""); } diff --git a/tensorflow/contrib/ffmpeg/decode_video_op.cc b/tensorflow/contrib/ffmpeg/decode_video_op.cc index 6f8ad486d10..0bfdc2781aa 100644 --- a/tensorflow/contrib/ffmpeg/decode_video_op.cc +++ b/tensorflow/contrib/ffmpeg/decode_video_op.cc @@ -45,7 +45,8 @@ class DecodeVideoOp : public OpKernel { errors::InvalidArgument( "contents must be a rank-0 tensor but got shape ", contents_tensor.shape().DebugString())); - const tensorflow::StringPiece contents = contents_tensor.scalar()(); + const tensorflow::StringPiece contents = + contents_tensor.scalar()(); // Write the input data to a temp file. string extension; diff --git a/tensorflow/contrib/ffmpeg/encode_audio_op.cc b/tensorflow/contrib/ffmpeg/encode_audio_op.cc index 7de09e062ec..ee418fb9020 100644 --- a/tensorflow/contrib/ffmpeg/encode_audio_op.cc +++ b/tensorflow/contrib/ffmpeg/encode_audio_op.cc @@ -45,7 +45,7 @@ void Encode(OpKernelContext* context, const Tensor& contents, // Copy the encoded audio file to the output tensor. Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), &output)); - output->scalar()() = encoded_audio; + output->scalar()() = encoded_audio; } } // namespace @@ -95,7 +95,7 @@ class EncodeAudioOpV2 : public OpKernel { bits_per_second_tensor.shape().DebugString())); const string file_format = - absl::AsciiStrToLower(file_format_tensor.scalar()()); + absl::AsciiStrToLower(file_format_tensor.scalar()()); const int32 samples_per_second = samples_per_second_tensor.scalar()(); const int32 bits_per_second = bits_per_second_tensor.scalar()(); diff --git a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc index a4084e79753..48491b9f051 100644 --- a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc +++ b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc @@ -198,7 +198,7 @@ class SequenceFileDatasetOp : public DatasetOpKernel { std::vector filenames; filenames.reserve(filenames_tensor->NumElements()); for (int i = 0; i < filenames_tensor->NumElements(); ++i) { - filenames.push_back(filenames_tensor->flat()(i)); + filenames.push_back(filenames_tensor->flat()(i)); } *output = new Dataset(ctx, filenames, output_types_); @@ -264,11 +264,11 @@ class SequenceFileDatasetOp : public DatasetOpKernel { TF_RETURN_IF_ERROR(status); Tensor key_tensor(ctx->allocator({}), DT_STRING, {}); - key_tensor.scalar()() = key; + key_tensor.scalar()() = key; out_tensors->emplace_back(std::move(key_tensor)); Tensor value_tensor(ctx->allocator({}), DT_STRING, {}); - value_tensor.scalar()() = value; + value_tensor.scalar()() = value; out_tensors->emplace_back(std::move(value_tensor)); *end_of_sequence = false; diff --git a/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc b/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc index 4218ec05f2c..41c9a8b1f49 100644 --- a/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc @@ -73,7 +73,7 @@ Status BinaryObjectParser::Parse(uint8_t** ptr, } case STRING: { out_tensors->emplace_back(cpu_allocator(), DT_STRING, TensorShape({})); - out_tensors->back().scalar()() = ParseString(ptr); + out_tensors->back().scalar()() = ParseString(ptr); break; } case DATE: { @@ -150,7 +150,7 @@ Status BinaryObjectParser::Parse(uint8_t** ptr, out_tensors->emplace_back(cpu_allocator(), DT_STRING, TensorShape({length})); for (int32_t i = 0; i < length; i++) - out_tensors->back().vec()(i) = ParseString(ptr); + out_tensors->back().vec()(i) = ParseString(ptr); break; } case DATE_ARR: { diff --git a/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc b/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc index 886f6798150..d5da76a753f 100644 --- a/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc +++ b/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc @@ -30,7 +30,7 @@ class ObtainNextOp : public OpKernel { const Tensor* list; OP_REQUIRES_OK(ctx, ctx->input("list", &list)); int64 num_elements = list->NumElements(); - auto list_flat = list->flat(); + auto list_flat = list->flat(); // Allocate output. Tensor* output_tensor = nullptr; @@ -48,7 +48,7 @@ class ObtainNextOp : public OpKernel { *pos = (*pos + 1) % num_elements; // Assign value to output. - output_tensor->scalar()() = list_flat(*pos); + output_tensor->scalar()() = list_flat(*pos); } }; diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc index bb0d4c178dc..8e0e7133686 100644 --- a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc +++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc @@ -33,7 +33,7 @@ class KafkaDatasetOp : public DatasetOpKernel { std::vector topics; topics.reserve(topics_tensor->NumElements()); for (int i = 0; i < topics_tensor->NumElements(); ++i) { - topics.push_back(topics_tensor->flat()(i)); + topics.push_back(topics_tensor->flat()(i)); } std::string servers = ""; @@ -128,7 +128,7 @@ class KafkaDatasetOp : public DatasetOpKernel { if (message->err() == RdKafka::ERR_NO_ERROR) { // Produce the line as output. Tensor line_tensor(cpu_allocator(), DT_STRING, {}); - line_tensor.scalar()() = + line_tensor.scalar()() = std::string(static_cast(message->payload()), message->len()); out_tensors->emplace_back(std::move(line_tensor)); diff --git a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc index ee4b0373ef7..0923bdd32bb 100644 --- a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc +++ b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc @@ -78,7 +78,7 @@ template <> int64 SparseTensorColumn::Feature(int64 batch, int64 n) const { const int64 start = feature_start_indices_[batch]; if (DT_STRING == values_.dtype()) - return Fingerprint64(values_.vec().data()[start + n]); + return Fingerprint64(values_.vec().data()[start + n]); return values_.vec().data()[start + n]; } @@ -87,7 +87,7 @@ template <> string SparseTensorColumn::Feature(int64 batch, int64 n) const { const int64 start = feature_start_indices_[batch]; if (DT_STRING == values_.dtype()) - return values_.vec().data()[start + n]; + return values_.vec().data()[start + n]; return std::to_string(values_.vec().data()[start + n]); } @@ -95,7 +95,7 @@ template <> StringPiece SparseTensorColumn::Feature(int64 batch, int64 n) const { const int64 start = feature_start_indices_[batch]; - return values_.vec().data()[start + n]; + return values_.vec().data()[start + n]; } // A column that is backed by a dense tensor. @@ -118,21 +118,21 @@ class DenseTensorColumn : public ColumnInterface { template <> int64 DenseTensorColumn::Feature(int64 batch, int64 n) const { if (DT_STRING == tensor_.dtype()) - return Fingerprint64(tensor_.matrix()(batch, n)); + return Fingerprint64(tensor_.matrix()(batch, n)); return tensor_.matrix()(batch, n); } // Internal type is string or StringPiece when using StringCrosser. template <> string DenseTensorColumn::Feature(int64 batch, int64 n) const { - if (DT_STRING == tensor_.dtype()) return tensor_.matrix()(batch, n); + if (DT_STRING == tensor_.dtype()) return tensor_.matrix()(batch, n); return std::to_string(tensor_.matrix()(batch, n)); } template <> StringPiece DenseTensorColumn::Feature(int64 batch, int64 n) const { - return tensor_.matrix()(batch, n); + return tensor_.matrix()(batch, n); } // Updates Output tensors with sparse crosses. diff --git a/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc b/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc index 720c74e3de5..f35453f267e 100644 --- a/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc +++ b/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc @@ -36,7 +36,7 @@ class DecodeLibsvmOp : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor* input_tensor; OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); - const auto& input_flat = input_tensor->flat(); + const auto& input_flat = input_tensor->flat(); Tensor* label_tensor; OP_REQUIRES_OK( diff --git a/tensorflow/contrib/session_bundle/session_bundle.cc b/tensorflow/contrib/session_bundle/session_bundle.cc index a690d9b129a..996e4ce0b80 100644 --- a/tensorflow/contrib/session_bundle/session_bundle.cc +++ b/tensorflow/contrib/session_bundle/session_bundle.cc @@ -72,7 +72,7 @@ Status GetMetaGraphDefFromExport(const StringPiece export_dir, // Creates a string tensor. Tensor CreateStringTensor(const string& value) { Tensor tensor(DT_STRING, TensorShape({})); - tensor.scalar()() = value; + tensor.scalar()() = value; return tensor; } diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc index 94650fe108b..5f997c2fba0 100644 --- a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc @@ -52,7 +52,7 @@ class CreateTreeVariableOp : public OpKernel { auto* result = new DecisionTreeResource(param_proto_); if (!ParseProtoUnlimited(result->mutable_decision_tree(), - tree_config_t->scalar()())) { + tree_config_t->scalar()())) { result->Unref(); OP_REQUIRES(context, false, errors::InvalidArgument("Unable to parse tree config.")); @@ -85,7 +85,7 @@ class TreeSerializeOp : public OpKernel { Tensor* output_config_t = nullptr; OP_REQUIRES_OK( context, context->allocate_output(0, TensorShape(), &output_config_t)); - output_config_t->scalar()() = + output_config_t->scalar()() = decision_tree_resource->decision_tree().SerializeAsString(); } }; @@ -116,7 +116,7 @@ class TreeDeserializeOp : public OpKernel { decision_trees::Model* config = decision_tree_resource->mutable_decision_tree(); OP_REQUIRES(context, - ParseProtoUnlimited(config, tree_config_t->scalar()()), + ParseProtoUnlimited(config, tree_config_t->scalar()()), errors::InvalidArgument("Unable to parse tree config.")); decision_tree_resource->MaybeInitialize(); } @@ -224,7 +224,7 @@ class TreePredictionsV4Op : public OpKernel { : 0); OP_REQUIRES_OK(context, context->allocate_output(1, output_paths_shape, &output_tree_paths)); - auto out_paths = output_tree_paths->unaligned_flat(); + auto out_paths = output_tree_paths->unaligned_flat(); // TODO(gilberth): If this slows down inference too much, consider having // a filter that only serializes paths for the predicted label that we're diff --git a/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc b/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc index b21a9179777..fcea240dee9 100644 --- a/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc +++ b/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc @@ -38,7 +38,7 @@ float Convert(const string& in) { void Evaluate(const Tensor& input_data, Tensor output_data, int32 start, int32 end) { auto out_data = output_data.unaligned_flat(); - const auto in_data = input_data.unaligned_flat(); + const auto in_data = input_data.unaligned_flat(); for (int32 i = start; i < end; ++i) { out_data(i) = Convert(in_data(i)); diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc index ede6e1abc9f..e4693cf68dc 100644 --- a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc @@ -56,7 +56,7 @@ class CreateFertileStatsVariableOp : public OpKernel { errors::InvalidArgument("Stats config must be a scalar.")); auto* result = new FertileStatsResource(param_proto_); FertileStats stats; - if (!ParseProtoUnlimited(&stats, stats_config_t->scalar()())) { + if (!ParseProtoUnlimited(&stats, stats_config_t->scalar()())) { result->Unref(); OP_REQUIRES(context, false, errors::InvalidArgument("Unable to parse stats config.")); @@ -98,7 +98,7 @@ class FertileStatsSerializeOp : public OpKernel { FertileStats stats; fertile_stats_resource->PackToProto(&stats); - output_config_t->scalar()() = stats.SerializeAsString(); + output_config_t->scalar()() = stats.SerializeAsString(); } private: @@ -128,9 +128,10 @@ class FertileStatsDeserializeOp : public OpKernel { // Deallocate all the previous objects on the resource. fertile_stats_resource->Reset(); FertileStats stats; - OP_REQUIRES(context, - ParseProtoUnlimited(&stats, stats_config_t->scalar()()), - errors::InvalidArgument("Unable to parse stats config.")); + OP_REQUIRES( + context, + ParseProtoUnlimited(&stats, stats_config_t->scalar()()), + errors::InvalidArgument("Unable to parse stats config.")); fertile_stats_resource->ExtractFromProto(stats); fertile_stats_resource->MaybeInitialize(); diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc b/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc index f4a7058ddb8..417cb6f7420 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc @@ -103,7 +103,7 @@ float CandidateGraphRunner::SplitScore() { void CandidateGraphRunner::GetSplit(decision_trees::BinaryNode* node) { std::vector outputs; RunOp(kNoOp, TensorNameValueList(), {kGetSplitName}, &outputs); - ParseProtoUnlimited(node, outputs[0].unaligned_flat()(0)); + ParseProtoUnlimited(node, outputs[0].unaligned_flat()(0)); const auto& oblique = split_.inequality_left_child_test().oblique(); auto* new_split = node->mutable_inequality_left_child_test()->mutable_oblique(); diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index 8da13aaca22..b073d1ae568 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -1055,9 +1055,9 @@ class SessionMetadataReaderOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->allocate_output("y", TensorShape({}), &out_tensor)); if (ctx->session_metadata() != nullptr) { - out_tensor->scalar()() = ctx->session_metadata()->DebugString(); + out_tensor->scalar()() = ctx->session_metadata()->DebugString(); } else { - out_tensor->scalar()() = ""; + out_tensor->scalar()() = ""; } } }; @@ -1079,7 +1079,7 @@ TEST(DirectSessionTest, SessionMetadataAbsent) { run_opts.set_inter_op_thread_pool(-1); auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr); - EXPECT_EQ("", outputs[0].scalar()()); + EXPECT_EQ("", outputs[0].scalar()()); } TEST(DirectSessionTest, SessionMetadataPresent) { @@ -1104,7 +1104,7 @@ TEST(DirectSessionTest, SessionMetadataPresent) { SessionMetadata read_metadata; ASSERT_TRUE(protobuf::TextFormat::ParseFromString( - outputs[0].scalar()(), &read_metadata)); + outputs[0].scalar()(), &read_metadata)); EXPECT_EQ("name", read_metadata.name()); EXPECT_EQ(1, read_metadata.version()); } @@ -1468,7 +1468,7 @@ TEST(DirectSessionTest, RunHandleTest) { const ResourceHandle& resource_handle = outputs[0].scalar()(); Tensor string_handle(DT_STRING, {}); - string_handle.flat().setConstant(resource_handle.name()); + string_handle.flat().setConstant(resource_handle.name()); // Second run call: Use a handle. std::vector outputs1; @@ -1521,7 +1521,7 @@ TEST(DirectSessionTest, RunHandleTest_Callable) { const ResourceHandle& resource_handle = outputs[0].scalar()(); Tensor string_handle(DT_STRING, {}); - string_handle.flat().setConstant(resource_handle.name()); + string_handle.flat().setConstant(resource_handle.name()); // Second run call: Use a handle. std::vector outputs1; diff --git a/tensorflow/core/common_runtime/function_testlib.cc b/tensorflow/core/common_runtime/function_testlib.cc index 1720ee64c07..bbaa94d6143 100644 --- a/tensorflow/core/common_runtime/function_testlib.cc +++ b/tensorflow/core/common_runtime/function_testlib.cc @@ -33,7 +33,7 @@ class FindDeviceOpKernel : public OpKernel { Tensor* device_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output("device_name", TensorShape{}, &device_tensor)); - device_tensor->scalar()() = + device_tensor->scalar()() = ctx->function_library()->device()->name(); } }; diff --git a/tensorflow/core/common_runtime/rendezvous_util_test.cc b/tensorflow/core/common_runtime/rendezvous_util_test.cc index 093fa7921f5..cb3fc45499d 100644 --- a/tensorflow/core/common_runtime/rendezvous_util_test.cc +++ b/tensorflow/core/common_runtime/rendezvous_util_test.cc @@ -33,7 +33,7 @@ class RendezvousUtilTest : public ::testing::Test { // string -> Tensor Tensor V(const string& content) { Tensor tensor(DT_STRING, TensorShape({})); - tensor.scalar()() = content; + tensor.scalar()() = content; return tensor; } @@ -41,7 +41,7 @@ Tensor V(const string& content) { string V(const Tensor& tensor) { CHECK_EQ(tensor.dtype(), DT_STRING); CHECK(TensorShapeUtils::IsScalar(tensor.shape())); - return tensor.scalar()(); + return tensor.scalar()(); } string MakeStringKey(const string& name) { diff --git a/tensorflow/core/debug/debug_grpc_io_utils_test.cc b/tensorflow/core/debug/debug_grpc_io_utils_test.cc index c857f12e755..26fd376cc6a 100644 --- a/tensorflow/core/debug/debug_grpc_io_utils_test.cc +++ b/tensorflow/core/debug/debug_grpc_io_utils_test.cc @@ -147,7 +147,7 @@ TEST_F(GrpcDebugTest, SendSingleDebugTensorViaGrpcTest) { TEST_F(GrpcDebugTest, SendDebugTensorWithLargeStringAtIndex0ViaGrpcTest) { Tensor tensor(DT_STRING, TensorShape({1, 1})); - tensor.flat()(0) = string(5000 * 1024, 'A'); + tensor.flat()(0) = string(5000 * 1024, 'A'); const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "foo_tensor", 0, "DebugIdentity"); const Status status = DebugIO::PublishDebugTensor( @@ -162,8 +162,8 @@ TEST_F(GrpcDebugTest, SendDebugTensorWithLargeStringAtIndex0ViaGrpcTest) { TEST_F(GrpcDebugTest, SendDebugTensorWithLargeStringAtIndex1ViaGrpcTest) { Tensor tensor(DT_STRING, TensorShape({1, 2})); - tensor.flat()(0) = "A"; - tensor.flat()(1) = string(5000 * 1024, 'A'); + tensor.flat()(0) = "A"; + tensor.flat()(1) = string(5000 * 1024, 'A'); const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "foo_tensor", 0, "DebugIdentity"); const Status status = DebugIO::PublishDebugTensor( diff --git a/tensorflow/core/debug/debug_io_utils_test.cc b/tensorflow/core/debug/debug_io_utils_test.cc index 928a82b0611..3eebcb3f138 100644 --- a/tensorflow/core/debug/debug_io_utils_test.cc +++ b/tensorflow/core/debug/debug_io_utils_test.cc @@ -47,8 +47,8 @@ class DebugIOUtilsTest : public ::testing::Test { tensor_a_->flat()(3) = 0.0; tensor_b_.reset(new Tensor(DT_STRING, TensorShape{2})); - tensor_b_->flat()(0) = "corge"; - tensor_b_->flat()(1) = "garply"; + tensor_b_->flat()(0) = "corge"; + tensor_b_->flat()(1) = "garply"; } Env* env_; @@ -182,8 +182,8 @@ TEST_F(DebugIOUtilsTest, DumpStringTensorToFileSunnyDay) { // Verify tensor shape and value. ASSERT_EQ(tensor_b_->shape(), b_prime.shape()); - for (int i = 0; i < b_prime.flat().size(); ++i) { - ASSERT_EQ(tensor_b_->flat()(i), b_prime.flat()(i)); + for (int i = 0; i < b_prime.flat().size(); ++i) { + ASSERT_EQ(tensor_b_->flat()(i), b_prime.flat()(i)); } // Tear down temporary file and directories. diff --git a/tensorflow/core/debug/grpc_session_debug_test.cc b/tensorflow/core/debug/grpc_session_debug_test.cc index 642a2a4c07d..65ec1ef8a6d 100644 --- a/tensorflow/core/debug/grpc_session_debug_test.cc +++ b/tensorflow/core/debug/grpc_session_debug_test.cc @@ -231,7 +231,7 @@ TEST_F(GrpcSessionDebugTest, MultiDevices_String) { Graph graph(OpRegistry::Global()); Tensor a_tensor(DT_STRING, TensorShape({2, 2})); for (size_t i = 0; i < 4; ++i) { - a_tensor.flat()(i) = "hello, world"; + a_tensor.flat()(i) = "hello, world"; } Node* a = test::graph::Constant(&graph, a_tensor); Node* b = test::graph::Identity(&graph, a); @@ -266,7 +266,7 @@ TEST_F(GrpcSessionDebugTest, MultiDevices_String) { ASSERT_EQ(outputs[0].dtype(), DT_STRING); ASSERT_EQ(outputs[0].NumElements(), 4); for (size_t i = 0; i < outputs[0].NumElements(); ++i) { - EXPECT_EQ(outputs[0].flat()(i), "hello, world"); + EXPECT_EQ(outputs[0].flat()(i), "hello, world"); } TF_CHECK_OK(session->Close()); @@ -278,7 +278,7 @@ TEST_F(GrpcSessionDebugTest, MultiDevices_String) { ASSERT_EQ(1, dumped_tensors.size()); ASSERT_EQ(TensorShape({2, 2}), dumped_tensors[0].shape()); for (size_t i = 0; i < 4; ++i) { - ASSERT_EQ("hello, world", dumped_tensors[0].flat()(i)); + ASSERT_EQ("hello, world", dumped_tensors[0].flat()(i)); } DeleteDumpDir(); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc index 3635caf3d10..8be6f1d6994 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc @@ -187,8 +187,8 @@ void GrpcRPCFactory::CreateCall(const Tensor& request_t, const bool try_rpc, void GrpcRPCFactory::StartCall(const Tensor& address_t, const Tensor& method_t, GrpcCall* call) { - auto address = address_t.flat(); - auto method = method_t.flat(); + auto address = address_t.flat(); + auto method = method_t.flat(); // Stubs are maintained by the GrpcRPCFactory class and will be // deleted when the class is destroyed. ::grpc::GenericStub* singleton_stub = nullptr; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc index c38b89b9c6f..7f2906efca6 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc @@ -501,7 +501,7 @@ TEST(GrpcSessionTest, MultiDevices_String) { Graph graph(OpRegistry::Global()); Tensor a_tensor(DT_STRING, TensorShape({2, 2})); for (int i = 0; i < 4; ++i) { - a_tensor.flat()(i) = "hello, world"; + a_tensor.flat()(i) = "hello, world"; } Node* a = test::graph::Constant(&graph, a_tensor); Node* b = test::graph::Identity(&graph, a); @@ -525,7 +525,7 @@ TEST(GrpcSessionTest, MultiDevices_String) { ASSERT_EQ(outputs[0].dtype(), DT_STRING); ASSERT_EQ(outputs[0].NumElements(), 4); for (int i = 0; i < outputs[0].NumElements(); ++i) { - EXPECT_EQ(outputs[0].flat()(i), "hello, world"); + EXPECT_EQ(outputs[0].flat()(i), "hello, world"); } TF_CHECK_OK(session->Close()); } else { diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc index 1483a65e3d9..5021853ce23 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc @@ -30,7 +30,7 @@ namespace tensorflow { // string -> Tensor Tensor V(const string& content) { Tensor tensor(DT_STRING, TensorShape({})); - tensor.scalar()() = content; + tensor.scalar()() = content; return tensor; } @@ -38,7 +38,7 @@ Tensor V(const string& content) { string V(const Tensor& tensor) { CHECK_EQ(tensor.dtype(), DT_STRING); CHECK(TensorShapeUtils::IsScalar(tensor.shape())); - return tensor.scalar()(); + return tensor.scalar()(); } Rendezvous::ParsedKey MakeKey(const string& s) { diff --git a/tensorflow/core/example/example_parser_configuration.cc b/tensorflow/core/example/example_parser_configuration.cc index 5660465c51a..af06c07eac9 100644 --- a/tensorflow/core/example/example_parser_configuration.cc +++ b/tensorflow/core/example/example_parser_configuration.cc @@ -114,13 +114,14 @@ Status ExtractExampleParserConfiguration( for (int i = 0; i < num_sparse; ++i) { int input_idx = sparse_keys_start + i; - (*var_len_features)[i].key = op_input_tensors[input_idx].scalar()(); + (*var_len_features)[i].key = + op_input_tensors[input_idx].scalar()(); } for (int i = 0; i < num_dense; ++i) { FixedLenFeature& config = (*fixed_len_features)[i]; int dense_keys_offset = dense_keys_start + i; - config.key = op_input_tensors[dense_keys_offset].scalar()(); + config.key = op_input_tensors[dense_keys_offset].scalar()(); int defaults_offset = dense_defaults_start + i; config.default_value = op_input_tensors[defaults_offset]; diff --git a/tensorflow/core/framework/op_compatibility_test.cc b/tensorflow/core/framework/op_compatibility_test.cc index dc931c38cd5..4edb60786d7 100644 --- a/tensorflow/core/framework/op_compatibility_test.cc +++ b/tensorflow/core/framework/op_compatibility_test.cc @@ -35,7 +35,7 @@ class TestKernel : public OpKernel { Tensor* out_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output("ndef", TensorShape({}), &out_tensor)); - out_tensor->scalar()() = SummarizeNodeDef(def()); + out_tensor->scalar()() = SummarizeNodeDef(def()); } }; @@ -87,7 +87,7 @@ class OpCompatibilityTest : public OpsTestBase { TF_ASSERT_OK(RunOpKernel()); } - string Result() { return GetOutput(0)->scalar()(); } + string Result() { return GetOutput(0)->scalar()(); } void ExpectIncompatible(const OpDef& old_op_def, const OpDef& new_op_def, const string& error) { diff --git a/tensorflow/core/framework/reader_base.cc b/tensorflow/core/framework/reader_base.cc index 39d83d9633b..ec27b8b89cb 100644 --- a/tensorflow/core/framework/reader_base.cc +++ b/tensorflow/core/framework/reader_base.cc @@ -214,7 +214,7 @@ string ReaderBase::GetNextWorkLocked(QueueInterface* queue, context->SetStatus(errors::InvalidArgument( "Expected to dequeue a one-element string tensor")); } else { - work = tuple[0].flat()(0); + work = tuple[0].flat()(0); } } n.Notify(); diff --git a/tensorflow/core/framework/rendezvous_test.cc b/tensorflow/core/framework/rendezvous_test.cc index 1c392fc8323..da9a1fbbe89 100644 --- a/tensorflow/core/framework/rendezvous_test.cc +++ b/tensorflow/core/framework/rendezvous_test.cc @@ -86,7 +86,7 @@ class LocalRendezvousTest : public ::testing::Test { // string -> Tensor Tensor V(const string& content) { Tensor tensor(DT_STRING, TensorShape({})); - tensor.scalar()() = content; + tensor.scalar()() = content; return tensor; } @@ -94,7 +94,7 @@ Tensor V(const string& content) { string V(const Tensor& tensor) { CHECK_EQ(tensor.dtype(), DT_STRING); CHECK(TensorShapeUtils::IsScalar(tensor.shape())); - return tensor.scalar()(); + return tensor.scalar()(); } Rendezvous::ParsedKey MakeKey(const string& name) { diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index 301fe686df1..67ea803511c 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -639,8 +639,8 @@ Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name, "Resource handle must have 2 elements, but had shape: ", tensor.shape().DebugString()); } - container = tensor.flat()(0); - shared_name = tensor.flat()(1); + container = tensor.flat()(0); + shared_name = tensor.flat()(1); } return ctx->resource_manager()->Lookup(container, shared_name, resource); } diff --git a/tensorflow/core/framework/resource_op_kernel.h b/tensorflow/core/framework/resource_op_kernel.h index fbcd439dea3..60e9703190c 100644 --- a/tensorflow/core/framework/resource_op_kernel.h +++ b/tensorflow/core/framework/resource_op_kernel.h @@ -96,7 +96,7 @@ class ResourceOpKernel : public OpKernel { } if (!has_resource_type_) { - auto h = handle_.AccessTensor(context)->template flat(); + auto h = handle_.AccessTensor(context)->template flat(); h(0) = cinfo_.container(); h(1) = cinfo_.name(); } diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc index d4aed387610..dd4ca706f28 100644 --- a/tensorflow/core/framework/tensor_test.cc +++ b/tensorflow/core/framework/tensor_test.cc @@ -480,7 +480,7 @@ TEST_F(TensorReshapeTest, ReshapeError) { Tensor string_tensor{DT_STRING, {10}}; // Note that the error message compare # of elements, not # of bytes. - EXPECT_DEATH((string_tensor.bit_casted_shaped({9})), "9 vs. 10"); + EXPECT_DEATH((string_tensor.bit_casted_shaped({9})), "9 vs. 10"); } TEST_F(TensorReshapeTest, Flat) { @@ -795,27 +795,27 @@ TEST(Tensor_Scalar, Basics) { { Tensor t(DT_STRING, TensorShape({})); EXPECT_EQ(1, t.NumElements()); - auto Tt = t.scalar(); + auto Tt = t.scalar(); EXPECT_EQ(1, Tt.size()); EXPECT_EQ(0, Tt.rank()); - t.scalar()() = "foo"; + t.scalar()() = "foo"; EXPECT_EQ("foo", Tt()); } { Tensor t(DT_STRING, TensorShape({1})); EXPECT_EQ(1, t.NumElements()); - auto Tt = t.vec(); + auto Tt = t.vec(); EXPECT_EQ(1, Tt.size()); - t.flat()(0) = "foo"; + t.flat()(0) = "foo"; EXPECT_EQ("foo", Tt(0)); } { Tensor t(DT_STRING, TensorShape({1, 1, 1})); EXPECT_EQ(1, t.NumElements()); - auto Tt = t.scalar(); + auto Tt = t.scalar(); EXPECT_EQ(1, Tt.size()); EXPECT_EQ(0, Tt.rank()); - t.flat()(0) = "bar"; + t.flat()(0) = "bar"; EXPECT_EQ("bar", Tt()); } { @@ -860,7 +860,7 @@ TEST(Tensor_HostScalar, Basics) { Tensor t("fooooooooooooooooooooooooooooooooooooo"); EXPECT_EQ(DT_STRING, t.dtype()); EXPECT_EQ(1, t.NumElements()); - auto Tt = t.scalar(); + auto Tt = t.scalar(); EXPECT_EQ(1, Tt.size()); EXPECT_EQ(0, Tt.rank()); EXPECT_EQ("fooooooooooooooooooooooooooooooooooooo", Tt()); @@ -980,7 +980,7 @@ TEST(Tensor_String, SimpleWithHelper) { Tensor t2(DT_STRING, {2, 3}); for (int i = 0; i < 2; ++i) { for (int j = 0; j < 3; ++j) { - t2.matrix()(i, j) = strings::StrCat(i * 3 + j); + t2.matrix()(i, j) = strings::StrCat(i * 3 + j); } } @@ -1163,7 +1163,7 @@ TEST(Tensor, FailureToAllocate) { // String { Tensor t(DT_STRING, TensorShape({1})); - t.vec()(0) = "foo"; + t.vec()(0) = "foo"; TensorProto proto; t.AsProtoField(&proto); diff --git a/tensorflow/core/framework/tensor_util.cc b/tensorflow/core/framework/tensor_util.cc index c87cc9548da..2e99626cb94 100644 --- a/tensorflow/core/framework/tensor_util.cc +++ b/tensorflow/core/framework/tensor_util.cc @@ -48,7 +48,7 @@ void DeepCopy(const Tensor& input, Tensor* output) { input_data.size()); } } else if (input.dtype() == DT_STRING) { - output->unaligned_flat() = input.unaligned_flat(); + output->unaligned_flat() = input.unaligned_flat(); } else { CHECK_EQ(DT_VARIANT, input.dtype()); output->unaligned_flat() = input.unaligned_flat(); @@ -103,7 +103,7 @@ Status Concat(const gtl::ArraySlice& tensors, Tensor* result) { int64 offset = 0; for (const Tensor& tensor : tensors) { - auto from_strings = tensor.flat(); + auto from_strings = tensor.flat(); CHECK_LE(offset + tensor.NumElements(), result->NumElements()); for (int i = 0; i < tensor.NumElements(); ++i) { to_strings[offset + i] = from_strings(i); @@ -155,7 +155,7 @@ Status Split(const Tensor& tensor, const gtl::ArraySlice& sizes, if (tensor.dtype() != DT_STRING) { return errors::Internal("Unexpected data type"); } - auto from_strings = tensor.flat(); + auto from_strings = tensor.flat(); int64 offset = 0; for (int64 size : sizes) { diff --git a/tensorflow/core/framework/tensor_util_test.cc b/tensorflow/core/framework/tensor_util_test.cc index 44708765bbf..fe988015e27 100644 --- a/tensorflow/core/framework/tensor_util_test.cc +++ b/tensorflow/core/framework/tensor_util_test.cc @@ -111,12 +111,12 @@ TEST(TensorUtil, DeepCopy) { // Test string deep copy Tensor str1(DT_STRING, TensorShape({2})); - str1.flat()(0) = "foo1"; - str1.flat()(1) = "foo2"; + str1.flat()(0) = "foo1"; + str1.flat()(1) = "foo2"; Tensor str2 = tensor::DeepCopy(str1); - str2.flat()(0) = "bar1"; - str2.flat()(1) = "bar2"; - EXPECT_NE(str2.flat()(0), str1.flat()(0)); + str2.flat()(0) = "bar1"; + str2.flat()(1) = "bar2"; + EXPECT_NE(str2.flat()(0), str1.flat()(0)); } TEST(TensorUtil, DeepCopySlice) { @@ -151,7 +151,7 @@ TEST(TensorUtil, DeepCopySlice) { TEST(TensorUtil, DeepCopySliceString) { Tensor x(DT_STRING, TensorShape({10})); - x.flat().setConstant("hello"); + x.flat().setConstant("hello"); // Slice 'x' -- y still refers to the same buffer. Tensor y = x.Slice(3, 7); @@ -160,7 +160,7 @@ TEST(TensorUtil, DeepCopySliceString) { Tensor z = tensor::DeepCopy(y); // Set x to be different. - x.flat().setConstant("goodbye"); + x.flat().setConstant("goodbye"); EXPECT_EQ(TensorShape({10}), x.shape()); EXPECT_EQ(TensorShape({4}), y.shape()); @@ -171,11 +171,11 @@ TEST(TensorUtil, DeepCopySliceString) { // x and y should now all be 'goodbye', but z should be 'hello'. for (int i = 0; i < 10; ++i) { - EXPECT_EQ("goodbye", x.flat()(i)); + EXPECT_EQ("goodbye", x.flat()(i)); } for (int i = 0; i < 4; ++i) { - EXPECT_EQ("goodbye", y.unaligned_flat()(i)); - EXPECT_EQ("hello", z.flat()(i)); + EXPECT_EQ("goodbye", y.unaligned_flat()(i)); + EXPECT_EQ("hello", z.flat()(i)); } } @@ -202,11 +202,12 @@ TEST(TensorUtil, DeepCopySliceVariant) { // Each element of x and y should now be a DT_STRING Tensor containing "foo", // but each element of z should be a DT_FLOAT tensor containing 42.0. for (int i = 0; i < 10; ++i) { - EXPECT_EQ("foo", x.flat()(i).get()->scalar()()); + EXPECT_EQ("foo", x.flat()(i).get()->scalar()()); } for (int i = 0; i < 4; ++i) { - EXPECT_EQ("foo", - y.unaligned_flat()(i).get()->scalar()()); + EXPECT_EQ( + "foo", + y.unaligned_flat()(i).get()->scalar()()); EXPECT_EQ(42.0, z.flat()(i).get()->scalar()()); } } @@ -271,7 +272,7 @@ TEST(TensorUtil, Split) { TEST(TensorUtil, ConcatSplitStrings) { Tensor x(DT_STRING, TensorShape({4, 3})); for (int i = 0; i < 4 * 3; ++i) { - x.flat()(i) = strings::StrCat("foo_", i); + x.flat()(i) = strings::StrCat("foo_", i); } std::vector split; @@ -280,15 +281,15 @@ TEST(TensorUtil, ConcatSplitStrings) { TF_ASSERT_OK(tensor::Concat(split, &x_round_tripped)); ASSERT_EQ(x.shape(), x_round_tripped.shape()); for (int i = 0; i < 4 * 3; ++i) { - EXPECT_EQ(x.flat()(i), x_round_tripped.flat()(i)); + EXPECT_EQ(x.flat()(i), x_round_tripped.flat()(i)); } // Ensure that no memory is being shared between 'x' and 'x_round_tripped'. for (int i = 0; i < 4 * 3; ++i) { - x_round_tripped.flat()(i) = strings::StrCat("bar_", i); + x_round_tripped.flat()(i) = strings::StrCat("bar_", i); } for (int i = 0; i < 4 * 3; ++i) { - EXPECT_NE(x.flat()(i), x_round_tripped.flat()(i)); + EXPECT_NE(x.flat()(i), x_round_tripped.flat()(i)); } } diff --git a/tensorflow/core/framework/variant_op_copy_test.cc b/tensorflow/core/framework/variant_op_copy_test.cc index 25cddc00a3a..19226d232ae 100644 --- a/tensorflow/core/framework/variant_op_copy_test.cc +++ b/tensorflow/core/framework/variant_op_copy_test.cc @@ -244,7 +244,7 @@ TEST(VariantOpCopyTest, CreateConstOnGPUFailsGracefully) { // Create the input StoredTensorValue and serialize it. StoredTensorValue from; from.stored = Tensor(DT_STRING, TensorShape({})); - from.stored.scalar()() = "hi"; + from.stored.scalar()() = "hi"; VariantTensorData data; data.set_type_name(from.TypeName()); from.Encode(&data); @@ -292,7 +292,7 @@ TEST(VariantOpCopyTest, CreateCopyCPUToCPU) { TEST(VariantOpCopyTest, CreateCopyCPUToCPUString) { Scope root = Scope::NewRootScope().WithDevice("/cpu:0"); Tensor t_str(DT_STRING, TensorShape({})); - t_str.scalar()() = "hi"; + t_str.scalar()() = "hi"; Output create_op = CreateTestVariant(root, t_str); Output identity = ops::Identity(root, create_op); @@ -309,7 +309,7 @@ TEST(VariantOpCopyTest, CreateCopyCPUToCPUString) { EXPECT_EQ("StoredTensorValue", r1.TypeName()); const StoredTensorValue* v1 = r1.get(); EXPECT_NE(v1, nullptr); - EXPECT_EQ("hi", v1->stored.scalar()()); + EXPECT_EQ("hi", v1->stored.scalar()()); } } @@ -356,7 +356,7 @@ TEST(VariantOpCopyTest, CreateCopyCPUToGPUStringFailsSafely) { Scope root = Scope::NewRootScope().WithDevice("/cpu:0"); Scope with_gpu = root.WithDevice("/gpu:0"); Tensor t_str(DT_STRING, TensorShape({})); - t_str.scalar()() = "hi"; + t_str.scalar()() = "hi"; Output create_op = CreateTestVariant(root, t_str); Output identity = ops::Identity(with_gpu, create_op); diff --git a/tensorflow/core/graph/quantize_training.cc b/tensorflow/core/graph/quantize_training.cc index 26bb6543569..4670e7a543c 100644 --- a/tensorflow/core/graph/quantize_training.cc +++ b/tensorflow/core/graph/quantize_training.cc @@ -172,8 +172,8 @@ StringPiece GetNodeNamePrefix(const Node* node) { } void FillStringTensor(Tensor* dst, const Tensor& src) { - auto dst_flat = dst->flat(); - auto src_flat = src.flat(); + auto dst_flat = dst->flat(); + auto src_flat = src.flat(); for (int i = 0; i < src.NumElements(); i++) { dst_flat(i) = src_flat(i); } @@ -220,8 +220,8 @@ Status ConnectVariablesToSaveOp(Graph* graph, Node* save_op, FillStringTensor(&new_shape_and_slices, shape_and_slices); for (int i = 0; i < var_size; i++) { Node* var = added_variables[i]; - new_tensor_names.flat()(tn_size + i) = var->name(); - new_shape_and_slices.flat()(tn_size + i) = ""; + new_tensor_names.flat()(tn_size + i) = var->name(); + new_shape_and_slices.flat()(tn_size + i) = ""; var_nodeouts.emplace_back(var); } save_op_builder = save_op_builder.Input(var_nodeouts); @@ -275,7 +275,7 @@ Status AddRestoreVariableSubgraphs(Graph* graph, Node* save_op, // Construct the tensor_names input with the variable name. Node* tensor_names; Tensor tensor_names_val(DT_STRING, TensorShape({1})); - tensor_names_val.flat()(0) = var->name(); + tensor_names_val.flat()(0) = var->name(); TF_RETURN_IF_ERROR(NodeBuilder(tensor_names_op_name, "Const") .Attr("dtype", DT_STRING) .Attr("value", tensor_names_val) @@ -284,7 +284,7 @@ Status AddRestoreVariableSubgraphs(Graph* graph, Node* save_op, // Construct the shape_and_slices input with empty string. Node* shape_and_slices; Tensor shape_and_slices_val(DT_STRING, TensorShape({1})); - shape_and_slices_val.flat()(0) = ""; + shape_and_slices_val.flat()(0) = ""; TF_RETURN_IF_ERROR(NodeBuilder(shape_and_slices_op_name, "Const") .Attr("dtype", DT_STRING) .Attr("value", shape_and_slices_val) diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc index d45bb14e070..198b6039b66 100644 --- a/tensorflow/core/grappler/costs/utils.cc +++ b/tensorflow/core/grappler/costs/utils.cc @@ -130,7 +130,7 @@ static void ExtractExtraProperties( if (tensor.NumElements() != 1) { continue; } - const string filename = tensor.scalar()(); + const string filename = tensor.scalar()(); Env* env = Env::Default(); FileStatistics stat; diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc index 5b3e140f23d..7be98dc43b4 100644 --- a/tensorflow/core/grappler/graph_view_test.cc +++ b/tensorflow/core/grappler/graph_view_test.cc @@ -98,7 +98,7 @@ TEST_F(GraphViewTest, OpPortIdToArgIdSparseSplit) { TEST_F(GraphViewTest, ParseSingleExample) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output a = ops::Const(s.WithOpName("a"), "", {}); + Output a = ops::Const(s.WithOpName("a"), "", {}); Output b = ops::Const(s.WithOpName("b"), 1, {1, 1}); ops::ParseSingleExample c(s.WithOpName("c"), a, {b, b}, 2, {"w", "x"}, {"y", "z"}, {DT_INT64, DT_INT64}, {{1}, {1}}); diff --git a/tensorflow/core/kernels/as_string_op.cc b/tensorflow/core/kernels/as_string_op.cc index e6d6c40f760..8341909fbc8 100644 --- a/tensorflow/core/kernels/as_string_op.cc +++ b/tensorflow/core/kernels/as_string_op.cc @@ -116,7 +116,7 @@ class AsStringOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output("output", input_tensor->shape(), &output_tensor)); - auto output_flat = output_tensor->flat(); + auto output_flat = output_tensor->flat(); #define ENCODE_TYPE(type, T, enc_str) \ case (type): { \ diff --git a/tensorflow/core/kernels/barrier_ops.cc b/tensorflow/core/kernels/barrier_ops.cc index 89d742c2daf..adbe370395c 100644 --- a/tensorflow/core/kernels/barrier_ops.cc +++ b/tensorflow/core/kernels/barrier_ops.cc @@ -308,7 +308,7 @@ class Barrier : public ResourceBase { int component_index, int i, std::vector* ready_tuples, bool* new_elements) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - auto keys_vec = keys.flat(); + auto keys_vec = keys.flat(); auto values_matrix = values.flat_outer_dims(); PersistentTuple* element_ptr; @@ -392,7 +392,7 @@ class Barrier : public ResourceBase { &key, &allocated_key)); ready_tuple.push_back(*element[0].AccessTensor(ctx)); // index ready_tuple.push_back(*allocated_key); // key - ready_tuple[1].scalar()() = keys_vec(i); // set the key + ready_tuple[1].scalar()() = keys_vec(i); // set the key for (int j = 1; j < num_components() + 1; ++j) { ready_tuple.push_back(*element[j].AccessTensor(ctx)); } diff --git a/tensorflow/core/kernels/base64_ops.cc b/tensorflow/core/kernels/base64_ops.cc index 74e6b39390a..cb235f56615 100644 --- a/tensorflow/core/kernels/base64_ops.cc +++ b/tensorflow/core/kernels/base64_ops.cc @@ -36,8 +36,8 @@ class EncodeBase64Op : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); - auto input = input_tensor.flat(); - auto output = output_tensor->flat(); + auto input = input_tensor.flat(); + auto output = output_tensor->flat(); for (int64 i = 0; i < input.dimension(0); ++i) { OP_REQUIRES_OK(context, Base64Encode(input(i), pad_, &output(i))); @@ -61,8 +61,8 @@ class DecodeBase64Op : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); - auto input = input_tensor.flat(); - auto output = output_tensor->flat(); + auto input = input_tensor.flat(); + auto output = output_tensor->flat(); for (int64 i = 0; i < input.dimension(0); ++i) { OP_REQUIRES_OK(context, Base64Decode(input(i), &output(i))); diff --git a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc index 718cf8e4139..7cd62af3a95 100644 --- a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc +++ b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc @@ -324,7 +324,7 @@ class BoostedTreesExampleDebugOutputsOp : public OpKernel { context, context->allocate_output("examples_debug_outputs_serialized", {batch_size}, &output_debug_info_t)); // Will contain serialized protos, per example. - auto output_debug_info = output_debug_info_t->flat(); + auto output_debug_info = output_debug_info_t->flat(); const int32 last_tree = resource->num_trees() - 1; // For each given example, traverse through all trees keeping track of the diff --git a/tensorflow/core/kernels/boosted_trees/resource_ops.cc b/tensorflow/core/kernels/boosted_trees/resource_ops.cc index 5a9c3549041..ac1fb5652da 100644 --- a/tensorflow/core/kernels/boosted_trees/resource_ops.cc +++ b/tensorflow/core/kernels/boosted_trees/resource_ops.cc @@ -51,7 +51,7 @@ class BoostedTreesCreateEnsembleOp : public OpKernel { std::unique_ptr result( new BoostedTreesEnsembleResource()); if (!result->InitFromSerialized( - tree_ensemble_serialized_t->scalar()(), stamp_token)) { + tree_ensemble_serialized_t->scalar()(), stamp_token)) { result->Unref(); OP_REQUIRES( context, false, @@ -152,7 +152,7 @@ class BoostedTreesSerializeEnsembleOp : public OpKernel { Tensor* output_proto_t = nullptr; OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape(), &output_proto_t)); - output_proto_t->scalar()() = + output_proto_t->scalar()() = tree_ensemble_resource->SerializeAsString(); } }; @@ -187,7 +187,7 @@ class BoostedTreesDeserializeEnsembleOp : public OpKernel { OP_REQUIRES( context, tree_ensemble_resource->InitFromSerialized( - tree_ensemble_serialized_t->scalar()(), stamp_token), + tree_ensemble_serialized_t->scalar()(), stamp_token), errors::InvalidArgument("Unable to parse tree ensemble proto.")); } }; diff --git a/tensorflow/core/kernels/boosted_trees/stats_ops.cc b/tensorflow/core/kernels/boosted_trees/stats_ops.cc index de9e378e704..c421bff44ca 100644 --- a/tensorflow/core/kernels/boosted_trees/stats_ops.cc +++ b/tensorflow/core/kernels/boosted_trees/stats_ops.cc @@ -393,7 +393,7 @@ class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel { OP_REQUIRES_OK( context, context->allocate_output("split_with_default_directions", {num_nodes}, &output_split_types_t)); - auto output_split_types_vec = output_split_types_t->vec(); + auto output_split_types_vec = output_split_types_t->vec(); // Sets output tensors from vectors. for (int i = 0; i < num_nodes; ++i) { @@ -677,7 +677,7 @@ class BoostedTreesSparseCalculateBestFeatureSplitOp : public OpKernel { OP_REQUIRES_OK( context, context->allocate_output("split_with_default_directions", {num_nodes}, &output_split_types_t)); - auto output_split_types_vec = output_split_types_t->vec(); + auto output_split_types_vec = output_split_types_t->vec(); // Sets output tensors from vectors. for (int i = 0; i < num_nodes; ++i) { diff --git a/tensorflow/core/kernels/conditional_accumulator_base_op.h b/tensorflow/core/kernels/conditional_accumulator_base_op.h index ab54fc1d914..a2bfa2cdc8c 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base_op.h +++ b/tensorflow/core/kernels/conditional_accumulator_base_op.h @@ -113,7 +113,7 @@ class ConditionalAccumulatorBaseOp : public OpKernel { // Verify that the shared accumulator is compatible // with the requested arguments. TF_RETURN_IF_ERROR(accumulator->MatchesNodeDef(def())); - auto h = accumulator_handle_.AccessTensor(ctx)->template flat(); + auto h = accumulator_handle_.AccessTensor(ctx)->template flat(); h(0) = cinfo_.container(); h(1) = cinfo_.name(); accumulator_handle_set_ = true; diff --git a/tensorflow/core/kernels/conditional_accumulator_op.cc b/tensorflow/core/kernels/conditional_accumulator_op.cc index 2bbd0ec35fb..3c7fbe0c65a 100644 --- a/tensorflow/core/kernels/conditional_accumulator_op.cc +++ b/tensorflow/core/kernels/conditional_accumulator_op.cc @@ -85,7 +85,7 @@ class ResourceConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { void SetHandleToOutput(OpKernelContext* ctx) SHARED_LOCKS_REQUIRED(mu_) override { - auto h = accumulator_handle_.AccessTensor(ctx)->template flat(); + auto h = accumulator_handle_.AccessTensor(ctx)->template flat(); h(0) = cinfo_.container(); h(1) = cinfo_.name(); OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( diff --git a/tensorflow/core/kernels/data/dataset_ops.cc b/tensorflow/core/kernels/data/dataset_ops.cc index 58cd17482e7..e931755d36e 100644 --- a/tensorflow/core/kernels/data/dataset_ops.cc +++ b/tensorflow/core/kernels/data/dataset_ops.cc @@ -40,7 +40,7 @@ void DatasetToGraphOp::Compute(OpKernelContext* ctx) { ctx, AsGraphDef(ctx, dataset, SerializationContext({}), &graph_def)); Tensor* result; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result)); - result->scalar()() = graph_def.SerializeAsString(); + result->scalar()() = graph_def.SerializeAsString(); } void DatasetCardinalityOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc index a2d9bd8d062..682085262fb 100644 --- a/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc @@ -93,7 +93,7 @@ class CSVDatasetOp : public DatasetOpKernel { std::vector filenames; filenames.reserve(filenames_tensor->NumElements()); for (int i = 0; i < filenames_tensor->NumElements(); ++i) { - filenames.push_back(filenames_tensor->flat()(i)); + filenames.push_back(filenames_tensor->flat()(i)); } io::ZlibCompressionOptions zlib_compression_options = @@ -719,10 +719,10 @@ class CSVDatasetOp : public DatasetOpKernel { } case DT_STRING: { if (field.empty() || field == dataset()->na_value_) { - component.scalar()() = - dataset()->record_defaults_[output_idx].flat()(0); + component.scalar()() = + dataset()->record_defaults_[output_idx].flat()(0); } else { - component.scalar()() = string(field); + component.scalar()() = string(field); } break; } diff --git a/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc index f587fe9e4c7..d19085fc35c 100644 --- a/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc @@ -38,7 +38,7 @@ class LMDBDatasetOp : public DatasetOpKernel { std::vector filenames; filenames.reserve(filenames_tensor->NumElements()); for (int i = 0; i < filenames_tensor->NumElements(); ++i) { - filenames.push_back(filenames_tensor->flat()(i)); + filenames.push_back(filenames_tensor->flat()(i)); } *output = new Dataset(ctx, filenames); @@ -95,13 +95,13 @@ class LMDBDatasetOp : public DatasetOpKernel { out_tensors->emplace_back(ctx->allocator({}), DT_STRING, TensorShape({})); Tensor& key_tensor = out_tensors->back(); - key_tensor.scalar()() = string( + key_tensor.scalar()() = string( static_cast(mdb_key_.mv_data), mdb_key_.mv_size); out_tensors->emplace_back(ctx->allocator({}), DT_STRING, TensorShape({})); Tensor& value_tensor = out_tensors->back(); - value_tensor.scalar()() = + value_tensor.scalar()() = string(static_cast(mdb_value_.mv_data), mdb_value_.mv_size); diff --git a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc index fd4beb03e57..0ae425556bc 100644 --- a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc @@ -42,7 +42,7 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { const Tensor* patterns_t; OP_REQUIRES_OK(ctx, ctx->input("patterns", &patterns_t)); - const auto patterns = patterns_t->flat(); + const auto patterns = patterns_t->flat(); size_t num_patterns = static_cast(patterns.size()); std::vector pattern_strs; pattern_strs.reserve(num_patterns); @@ -126,7 +126,7 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { current_path.first.end(), '/', '\\'); } - filepath_tensor.scalar()() = + filepath_tensor.scalar()() = std::move(current_path.first); out_tensors->emplace_back(std::move(filepath_tensor)); *end_of_sequence = false; diff --git a/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc index 8a4089b580b..3fbb9bd79b9 100644 --- a/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc +++ b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc @@ -42,7 +42,7 @@ class IteratorGetDeviceOp : public OpKernel { // NOTE(mrry): Since the operation's input is a resource, we must be // colocated with it, and so we can simply return the current device's // name without looking at the input. - device_name_t->scalar()() = ctx->device()->name(); + device_name_t->scalar()() = ctx->device()->name(); } }; diff --git a/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc index a2a1330e29f..05dadf084d4 100644 --- a/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc +++ b/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc @@ -267,7 +267,7 @@ class StatsAggregatorSummaryOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &summary_t)); Summary summary; resource->stats_aggregator()->EncodeToProto(&summary); - summary_t->scalar()() = summary.SerializeAsString(); + summary_t->scalar()() = summary.SerializeAsString(); } }; diff --git a/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc b/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc index 5879750bf18..f45b493d851 100644 --- a/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc +++ b/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc @@ -122,7 +122,7 @@ class ToTFRecordOp : public AsyncOpKernel { if (!end_of_sequence) { OP_REQUIRES_OK_ASYNC( - ctx, writer->WriteRecord(components[0].scalar()()), + ctx, writer->WriteRecord(components[0].scalar()()), done); } components.clear(); diff --git a/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc index 3ef920107cd..64f728d58c1 100644 --- a/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc @@ -171,7 +171,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { return Hash64(t.tensor_data().data(), t.tensor_data().size()); } else { DCHECK_EQ(DT_STRING, t.dtype()); - auto flat_t = t.flat(); + auto flat_t = t.flat(); uint64 hash = 0; for (int64 i = 0; i < t.NumElements(); ++i) { hash = Hash64Combine(hash, Hash64(flat_t(i))); diff --git a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc index dd147c6fd95..976288f4a65 100644 --- a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc +++ b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc @@ -141,7 +141,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { // Produce the record as output. Tensor record_tensor(ctx->allocator({}), DT_STRING, {}); - record_tensor.scalar()() = record; + record_tensor.scalar()() = record; out_tensors->emplace_back(std::move(record_tensor)); *end_of_sequence = false; return Status::OK(); @@ -264,7 +264,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { // Produce the record as output. Tensor record_tensor(ctx->allocator({}), DT_STRING, {}); - record_tensor.scalar()() = std::move(record); + record_tensor.scalar()() = std::move(record); out_tensors->emplace_back(std::move(record_tensor)); *end_of_sequence = false; return Status::OK(); @@ -282,7 +282,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { lookahead_cache_.substr(dataset()->record_bytes_); // Produce the record as output. Tensor record_tensor(ctx->allocator({}), DT_STRING, {}); - record_tensor.scalar()() = std::move(record); + record_tensor.scalar()() = std::move(record); out_tensors->emplace_back(std::move(record_tensor)); *end_of_sequence = false; return Status::OK(); @@ -459,7 +459,7 @@ void FixedLengthRecordDatasetOp::MakeDataset(OpKernelContext* ctx, std::vector filenames; filenames.reserve(filenames_tensor->NumElements()); for (int i = 0; i < filenames_tensor->NumElements(); ++i) { - filenames.push_back(filenames_tensor->flat()(i)); + filenames.push_back(filenames_tensor->flat()(i)); } int64 header_bytes = -1; diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 4965c2ee09b..1acf4f4c1bd 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -1002,7 +1002,7 @@ void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) { Tensor* string_handle_t; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &string_handle_t)); - string_handle_t->scalar()() = + string_handle_t->scalar()() = resource_handle_t.scalar()().SerializeAsString(); } @@ -1026,7 +1026,7 @@ void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) { ResourceHandle resource_handle; OP_REQUIRES( - ctx, resource_handle.ParseFromString(string_handle_t.scalar()()), + ctx, resource_handle.ParseFromString(string_handle_t.scalar()()), errors::InvalidArgument( "Could not parse string_handle as a valid ResourceHandle")); diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc index 409a50371f0..7a538d77d1b 100644 --- a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc +++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc @@ -644,7 +644,7 @@ class MultiDeviceIteratorToStringHandleOp : public OpKernel { Tensor* string_handle_t; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &string_handle_t)); - string_handle_t->scalar()() = + string_handle_t->scalar()() = resource_handle_t.scalar()().SerializeAsString(); } }; @@ -675,7 +675,7 @@ class MultiDeviceIteratorFromStringHandleOp : public OpKernel { ResourceHandle resource_handle; OP_REQUIRES( ctx, - resource_handle.ParseFromString(string_handle_t.scalar()()), + resource_handle.ParseFromString(string_handle_t.scalar()()), errors::InvalidArgument( "Could not parse string_handle as a valid ResourceHandle")); diff --git a/tensorflow/core/kernels/data/text_line_dataset_op.cc b/tensorflow/core/kernels/data/text_line_dataset_op.cc index b8302b890c8..a8ebf631c94 100644 --- a/tensorflow/core/kernels/data/text_line_dataset_op.cc +++ b/tensorflow/core/kernels/data/text_line_dataset_op.cc @@ -108,7 +108,7 @@ class TextLineDatasetOp::Dataset : public DatasetBase { line_contents.size()); out_tensors->emplace_back(ctx->allocator({}), DT_STRING, TensorShape({})); - out_tensors->back().scalar()() = std::move(line_contents); + out_tensors->back().scalar()() = std::move(line_contents); *end_of_sequence = false; return Status::OK(); } else if (!errors::IsOutOfRange(s)) { @@ -266,7 +266,7 @@ void TextLineDatasetOp::MakeDataset(OpKernelContext* ctx, std::vector filenames; filenames.reserve(filenames_tensor->NumElements()); for (int i = 0; i < filenames_tensor->NumElements(); ++i) { - filenames.push_back(filenames_tensor->flat()(i)); + filenames.push_back(filenames_tensor->flat()(i)); } *output = new Dataset(ctx, std::move(filenames), compression_type, diff --git a/tensorflow/core/kernels/data/tf_record_dataset_op.cc b/tensorflow/core/kernels/data/tf_record_dataset_op.cc index e35743dae60..2b26f61bf7d 100644 --- a/tensorflow/core/kernels/data/tf_record_dataset_op.cc +++ b/tensorflow/core/kernels/data/tf_record_dataset_op.cc @@ -108,7 +108,7 @@ class TFRecordDatasetOp::Dataset : public DatasetBase { reader_->ReadRecord(&out_tensors->back().scalar()()); if (s.ok()) { metrics::RecordTFDataBytesRead( - kDatasetType, out_tensors->back().scalar()().size()); + kDatasetType, out_tensors->back().scalar()().size()); *end_of_sequence = false; return Status::OK(); } @@ -224,8 +224,8 @@ void TFRecordDatasetOp::MakeDataset(OpKernelContext* ctx, std::vector filenames; filenames.reserve(filenames_tensor->NumElements()); for (int i = 0; i < filenames_tensor->NumElements(); ++i) { - VLOG(2) << "Reading file: " << filenames_tensor->flat()(i); - filenames.push_back(filenames_tensor->flat()(i)); + VLOG(2) << "Reading file: " << filenames_tensor->flat()(i); + filenames.push_back(filenames_tensor->flat()(i)); } string compression_type; diff --git a/tensorflow/core/kernels/decode_bmp_op.cc b/tensorflow/core/kernels/decode_bmp_op.cc index 8a9f7b18601..122b7ecb3da 100644 --- a/tensorflow/core/kernels/decode_bmp_op.cc +++ b/tensorflow/core/kernels/decode_bmp_op.cc @@ -54,7 +54,7 @@ class DecodeBmpOp : public OpKernel { contents.shape().DebugString())); // Start decoding image to get shape details - const StringPiece input = contents.scalar()(); + const StringPiece input = contents.scalar()(); OP_REQUIRES(context, (32 <= input.size()), errors::InvalidArgument("Incomplete bmp content, requires at " diff --git a/tensorflow/core/kernels/decode_compressed_op.cc b/tensorflow/core/kernels/decode_compressed_op.cc index 3c3d49e1f8f..78376cea569 100644 --- a/tensorflow/core/kernels/decode_compressed_op.cc +++ b/tensorflow/core/kernels/decode_compressed_op.cc @@ -84,13 +84,13 @@ class DecodeCompressedOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor* bytes_tensor; OP_REQUIRES_OK(context, context->input("bytes", &bytes_tensor)); - const auto& bytes_flat = bytes_tensor->flat(); + const auto& bytes_flat = bytes_tensor->flat(); Tensor* output_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output("output", bytes_tensor->shape(), &output_tensor)); - auto output_flat = output_tensor->flat(); + auto output_flat = output_tensor->flat(); if (compression_type_.empty()) { for (int64 i = 0; i < bytes_flat.size(); i++) { output_flat(i) = bytes_flat(i); diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc index ba6369533ad..470a7b3859a 100644 --- a/tensorflow/core/kernels/decode_csv_op.cc +++ b/tensorflow/core/kernels/decode_csv_op.cc @@ -70,7 +70,7 @@ class DecodeCSVOp : public OpKernel { " has ", record_defaults[i].NumElements())); } - auto records_t = records->flat(); + auto records_t = records->flat(); int64 records_size = records_t.size(); OpOutputList output; @@ -181,10 +181,10 @@ class DecodeCSVOp : public OpKernel { errors::InvalidArgument( "Field ", f, " is required but missing in record ", i, "!")); - output[f]->flat()(i) = - record_defaults[f].flat()(0); + output[f]->flat()(i) = + record_defaults[f].flat()(0); } else { - output[f]->flat()(i) = fields[f]; + output[f]->flat()(i) = fields[f]; } break; } diff --git a/tensorflow/core/kernels/decode_image_op.cc b/tensorflow/core/kernels/decode_image_op.cc index 052c9f24e4b..f89533d1574 100644 --- a/tensorflow/core/kernels/decode_image_op.cc +++ b/tensorflow/core/kernels/decode_image_op.cc @@ -154,7 +154,7 @@ class DecodeImageOp : public OpKernel { contents.shape().DebugString())); // Determine format - const StringPiece input = contents.scalar()(); + const StringPiece input = contents.scalar()(); const auto magic = ClassifyFileFormat(input); OP_REQUIRES( context, diff --git a/tensorflow/core/kernels/decode_padded_raw_op.cc b/tensorflow/core/kernels/decode_padded_raw_op.cc index 1e6a0cb7606..12e8ec6aff0 100644 --- a/tensorflow/core/kernels/decode_padded_raw_op.cc +++ b/tensorflow/core/kernels/decode_padded_raw_op.cc @@ -39,7 +39,7 @@ class DecodePaddedRawOp : public OpKernel { void Compute(OpKernelContext* context) override { const auto& input = context->input(0); - auto flat_in = input.flat(); + auto flat_in = input.flat(); int fixed_length; const auto& length_input = context->input(1); diff --git a/tensorflow/core/kernels/decode_proto_op.cc b/tensorflow/core/kernels/decode_proto_op.cc index 06dc766794c..5717fa53169 100644 --- a/tensorflow/core/kernels/decode_proto_op.cc +++ b/tensorflow/core/kernels/decode_proto_op.cc @@ -748,14 +748,14 @@ class DecodeProtoOp : public OpKernel { if (is_binary_ && !sanitize_) { // Fast path. for (int mi = 0; mi < message_count; ++mi) { - const string* buf = &buf_tensor.flat()(mi); + const tstring* buf = &buf_tensor.flat()(mi); bufs.push_back(buf); } } else { // We will have to allocate a copy, either to convert from text to binary // or to sanitize a binary proto. for (int mi = 0; mi < message_count; ++mi) { - ReserializeMessage(ctx, buf_tensor.flat()(mi), + ReserializeMessage(ctx, buf_tensor.flat()(mi), &tmp_binary_bufs[mi]); if (!ctx->status().ok()) { return; @@ -895,8 +895,8 @@ class DecodeProtoOp : public OpKernel { data = tensor->bit_casted_shaped(flatshape).data(); } else { // DataTypeSize() returns 0 for string types. - stride = last_dim_size * sizeof(string); - data = reinterpret_cast(tensor->flat().data()); + stride = last_dim_size * sizeof(tstring); + data = reinterpret_cast(tensor->flat().data()); } } diff --git a/tensorflow/core/kernels/decode_raw_op.cc b/tensorflow/core/kernels/decode_raw_op.cc index e68fa407534..942589608c0 100644 --- a/tensorflow/core/kernels/decode_raw_op.cc +++ b/tensorflow/core/kernels/decode_raw_op.cc @@ -41,7 +41,7 @@ class DecodeRawOp : public OpKernel { void Compute(OpKernelContext* context) override { const auto& input = context->input(0); int64 str_size = -1; - auto flat_in = input.flat(); + auto flat_in = input.flat(); for (int64 i = 0; i < flat_in.size(); ++i) { const string& in_str = flat_in(i); if (str_size == -1) { diff --git a/tensorflow/core/kernels/decode_wav_op.cc b/tensorflow/core/kernels/decode_wav_op.cc index 4bd5d7ac2a6..6325c28b13e 100644 --- a/tensorflow/core/kernels/decode_wav_op.cc +++ b/tensorflow/core/kernels/decode_wav_op.cc @@ -40,7 +40,7 @@ class DecodeWavOp : public OpKernel { OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()), errors::InvalidArgument("contents must be scalar, got shape ", contents.shape().DebugString())); - const string wav_string = contents.scalar()(); + const string wav_string = contents.scalar()(); OP_REQUIRES(context, wav_string.size() <= std::numeric_limits::max(), errors::InvalidArgument("WAV contents are too large for int: ", wav_string.size())); diff --git a/tensorflow/core/kernels/deserialize_sparse_string_op.cc b/tensorflow/core/kernels/deserialize_sparse_string_op.cc index d26d8188d51..398df428994 100644 --- a/tensorflow/core/kernels/deserialize_sparse_string_op.cc +++ b/tensorflow/core/kernels/deserialize_sparse_string_op.cc @@ -75,7 +75,7 @@ class DeserializeSparseOp : public OpKernel { if (num_sparse_tensors == 1 && ndims == 1) { // Special case with a single sparse tensor. We can avoid data // motion in the Concat and Reshape. - const auto& serialized_sparse_t = serialized_sparse.vec(); + const auto& serialized_sparse_t = serialized_sparse.vec(); Tensor output_indices; Tensor output_values; @@ -98,7 +98,7 @@ class DeserializeSparseOp : public OpKernel { values.reserve(num_sparse_tensors); const auto& serialized_sparse_t = - serialized_sparse.flat_inner_dims(); + serialized_sparse.flat_inner_dims(); for (int i = 0; i < num_sparse_tensors; ++i) { Tensor output_indices; Tensor output_values; diff --git a/tensorflow/core/kernels/encode_proto_op.cc b/tensorflow/core/kernels/encode_proto_op.cc index b023f1cdeb8..12bbd34ec71 100644 --- a/tensorflow/core/kernels/encode_proto_op.cc +++ b/tensorflow/core/kernels/encode_proto_op.cc @@ -303,7 +303,7 @@ Status WriteVarLenField(const FieldDescriptor& field_desc, const Tensor& input, // code it ourselves. Status WriteGroup(const FieldDescriptor& field_desc, const Tensor& input, int message_index, int size, CodedOutputStream* output) { - auto input_t = input.flat_inner_dims(); + auto input_t = input.flat_inner_dims(); for (int64 i = 0; i < size; i++) { const string& value = input_t(static_cast(message_index), i); WireFormatLite::WriteTag(field_desc.number(), @@ -587,7 +587,7 @@ class EncodeProtoOp : public OpKernel { Tensor* output_tensor; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, common_prefix, &output_tensor)); - auto bufs = output_tensor->flat(); + auto bufs = output_tensor->flat(); for (int message_index = 0; message_index < message_count; message_index++) { // TODO(nix): possibly optimize allocation here by calling diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc index 708b52a5174..783190b50ef 100644 --- a/tensorflow/core/kernels/example_parsing_ops.cc +++ b/tensorflow/core/kernels/example_parsing_ops.cc @@ -63,10 +63,10 @@ class ParseExampleOp : public OpKernel { // Copy from OpInputList to std::vector. for (int di = 0; di < attrs_.num_dense; ++di) { - dense_keys_t[di] = dense_keys[di].scalar()(); + dense_keys_t[di] = dense_keys[di].scalar()(); } for (int di = 0; di < attrs_.num_sparse; ++di) { - sparse_keys_t[di] = sparse_keys[di].scalar()(); + sparse_keys_t[di] = sparse_keys[di].scalar()(); } if (names->NumElements() > 0) { @@ -234,7 +234,7 @@ class ParseSingleExampleOp : public OpKernel { config.sparse.push_back({attrs_.sparse_keys[d], attrs_.sparse_types[d]}); } - const string& serialized_proto = serialized->scalar()(); + const string& serialized_proto = serialized->scalar()(); OP_REQUIRES_OK(ctx, FastParseSingleExample(config, serialized_proto, &result)); @@ -473,7 +473,7 @@ class ParseSingleSequenceExampleOp : public OpKernel { "Expected context_dense_keys[", di, "] to be a scalar, got shape: ", context_dense_keys[di].shape().DebugString())); - context_dense_keys_t[di] = context_dense_keys[di].scalar()(); + context_dense_keys_t[di] = context_dense_keys[di].scalar()(); } for (int di = 0; di < attrs_.num_context_sparse; ++di) { OP_REQUIRES(ctx, @@ -482,7 +482,7 @@ class ParseSingleSequenceExampleOp : public OpKernel { "Expected context_sparse_keys[", di, "] to be a scalar, got shape: ", context_sparse_keys[di].shape().DebugString())); - context_sparse_keys_t[di] = context_sparse_keys[di].scalar()(); + context_sparse_keys_t[di] = context_sparse_keys[di].scalar()(); } for (int di = 0; di < attrs_.num_feature_list_dense; ++di) { OP_REQUIRES( @@ -492,7 +492,7 @@ class ParseSingleSequenceExampleOp : public OpKernel { "] to be a scalar, got shape: ", feature_list_dense_keys[di].shape().DebugString())); feature_list_dense_keys_t[di] = - feature_list_dense_keys[di].scalar()(); + feature_list_dense_keys[di].scalar()(); } for (int di = 0; di < attrs_.num_feature_list_sparse; ++di) { OP_REQUIRES( @@ -502,7 +502,7 @@ class ParseSingleSequenceExampleOp : public OpKernel { "] to be a scalar, got shape: ", feature_list_sparse_keys[di].shape().DebugString())); feature_list_sparse_keys_t[di] = - feature_list_sparse_keys[di].scalar()(); + feature_list_sparse_keys[di].scalar()(); } OP_REQUIRES( ctx, @@ -513,7 +513,7 @@ class ParseSingleSequenceExampleOp : public OpKernel { "to be a vector, got shape: ", feature_list_dense_missing_assumed_empty->shape().DebugString())); auto feature_list_dense_missing_assumped_empty_t = - feature_list_dense_missing_assumed_empty->vec(); + feature_list_dense_missing_assumed_empty->vec(); for (int de = 0; de < feature_list_dense_missing_assumed_empty->NumElements(); ++de) { feature_list_dense_missing_assumed_empty_set.insert( @@ -527,7 +527,7 @@ class ParseSingleSequenceExampleOp : public OpKernel { "Expected debug_name to be a scalar, got shape: ", debug_name->shape().DebugString())); } - auto debug_name_t = debug_name->scalar(); + auto debug_name_t = debug_name->scalar(); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(serialized->shape()), errors::InvalidArgument( @@ -561,7 +561,7 @@ class ParseSingleSequenceExampleOp : public OpKernel { } } - auto serialized_t = serialized->scalar(); + auto serialized_t = serialized->scalar(); OpOutputList context_sparse_indices; OpOutputList context_sparse_values; diff --git a/tensorflow/core/kernels/example_parsing_ops_test.cc b/tensorflow/core/kernels/example_parsing_ops_test.cc index 4d843ab02cc..db1672e70a0 100644 --- a/tensorflow/core/kernels/example_parsing_ops_test.cc +++ b/tensorflow/core/kernels/example_parsing_ops_test.cc @@ -114,7 +114,7 @@ struct ExampleStore { Example example; Filler fill; Tensor record_string(DT_STRING, TensorShape({batch_size})); - auto string_t = record_string.vec(); + auto string_t = record_string.vec(); example.Clear(); for (int b = 0; b < batch_size; ++b) { for (int k = 0; k < num_keys; ++k) { @@ -163,7 +163,7 @@ static Graph* ParseExample(int batch_size, int num_keys, int feature_size) { Options opt; for (int i = 0; i < num_keys; ++i) { Tensor key(DT_STRING, TensorShape()); - key.scalar()() = strings::Printf("feature_%d", i); + key.scalar()() = strings::Printf("feature_%d", i); switch (opt.benchmark_type) { case kDense: dense_keys.emplace_back(test::graph::Constant(g, key)); @@ -205,7 +205,7 @@ static Graph* ParseSingleExample(int num_keys, int feature_size) { Options::Store::GetSerializedExample()[std::make_tuple(1, num_keys, feature_size)]; Tensor serialized(DT_STRING, TensorShape()); - serialized.scalar()() = serialized_batch_1.vec()(0); + serialized.scalar()() = serialized_batch_1.vec()(0); std::vector sparse_keys; std::vector dense_keys; diff --git a/tensorflow/core/kernels/extract_jpeg_shape_op.cc b/tensorflow/core/kernels/extract_jpeg_shape_op.cc index ab424595c1a..c74245dcf85 100644 --- a/tensorflow/core/kernels/extract_jpeg_shape_op.cc +++ b/tensorflow/core/kernels/extract_jpeg_shape_op.cc @@ -41,7 +41,7 @@ class ExtractJpegShapeOp : public OpKernel { OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()), errors::InvalidArgument("contents must be scalar, got shape ", contents.shape().DebugString())); - const StringPiece input = contents.scalar()(); + const StringPiece input = contents.scalar()(); OP_REQUIRES(context, input.size() <= std::numeric_limits::max(), errors::InvalidArgument("JPEG contents are too large for int: ", input.size())); diff --git a/tensorflow/core/kernels/fact_op.cc b/tensorflow/core/kernels/fact_op.cc index 4a1aa433bc9..6c11ab7a2d2 100644 --- a/tensorflow/core/kernels/fact_op.cc +++ b/tensorflow/core/kernels/fact_op.cc @@ -85,7 +85,7 @@ class FactOpKernel : public OpKernel { Tensor* output_tensor = nullptr; OP_REQUIRES_OK( context, context->allocate_output(0, TensorShape({}), &output_tensor)); - auto output = output_tensor->template scalar(); + auto output = output_tensor->template scalar(); string coded = facts[context->env()->NowMicros() % count]; E(&coded); diff --git a/tensorflow/core/kernels/fingerprint_op.cc b/tensorflow/core/kernels/fingerprint_op.cc index 20529326b3d..660f900c405 100644 --- a/tensorflow/core/kernels/fingerprint_op.cc +++ b/tensorflow/core/kernels/fingerprint_op.cc @@ -110,14 +110,14 @@ class FingerprintOp : public OpKernel { // and each row contains the fingerprint value of corresponding string. // To compute fingerprints of multiple strings, this op fingerprints the // buffer containing the string fingerprints. - FarmhashFingerprint64(input.flat(), temp.tensor()); + FarmhashFingerprint64(input.flat(), temp.tensor()); FarmhashFingerprint64(static_cast(temp).shaped( {dim0, dim1 * kFingerprintSize}), output->matrix()); } else { // In case dim1 == 1, each string computes into its own fingerprint // value. There is no need to fingerprint twice. - FarmhashFingerprint64(input.flat(), output->matrix()); + FarmhashFingerprint64(input.flat(), output->matrix()); } } else { auto data = input.bit_casted_shaped( diff --git a/tensorflow/core/kernels/fingerprint_op_test.cc b/tensorflow/core/kernels/fingerprint_op_test.cc index 14376cb2d35..d9a9a97798d 100644 --- a/tensorflow/core/kernels/fingerprint_op_test.cc +++ b/tensorflow/core/kernels/fingerprint_op_test.cc @@ -51,7 +51,7 @@ class FingerprintOpTest : public OpsTestBase { inputs_.push_back(TensorValue(data)); method_ = Tensor(DT_STRING, TensorShape{}); - method_.scalar()() = method; + method_.scalar()() = method; inputs_.push_back(TensorValue(&method_)); return Status::OK(); } @@ -77,7 +77,7 @@ TEST_F(FingerprintOpTest, GoldenValue) { // special-case handling. TEST_F(FingerprintOpTest, StringGoldenValue) { Tensor data(DT_STRING, {1, 2, 2}); - auto buffer = data.flat(); + auto buffer = data.flat(); buffer(0).resize(10); buffer(1).resize(7); buffer(2).resize(0); @@ -134,7 +134,7 @@ TEST_F(FingerprintOpTest, CollisionString) { constexpr int64 size = 256; Tensor tensor(DT_STRING, {1}); - auto& input = tensor.vec()(0); + auto& input = tensor.vec()(0); input.resize(size); TTypes::UnalignedFlat buffer(reinterpret_cast(&*input.begin()), @@ -163,7 +163,7 @@ TEST_F(FingerprintOpTest, CompareBytesAndString) { auto pods = pods_tensor.matrix(); pods.setRandom(); - auto strings = strings_tensor.vec(); + auto strings = strings_tensor.vec(); for (int64 i = 0; i < strings.size(); ++i) { strings(i).assign(reinterpret_cast(&pods(i, 0)), pods.dimension(1) * sizeof(pods(i, 0))); @@ -199,7 +199,7 @@ TEST(FingerprintOpShapeFnTest, MethodKnownStatically) { ShapeInferenceTestOp op("Fingerprint"); Tensor method(DT_STRING, TensorShape{}); - method.scalar()() = "farmhash64"; + method.scalar()() = "farmhash64"; op.input_tensors.assign({nullptr, &method}); TF_ASSERT_OK(MakeNodeDef(DT_UINT8, &op.node_def)); @@ -229,12 +229,12 @@ TEST(FingerprintOpShapeFnTest, InvalidMethod) { // When `method` shape is unknown statically. Tensor method(DT_STRING, TensorShape{1}); - method.vec()(0) = "farmhash64"; + method.vec()(0) = "farmhash64"; op.input_tensors.assign({nullptr, &method}); INFER_ERROR("must be rank 0", op, "?;?"); method = Tensor(DT_STRING, TensorShape{}); - method.scalar()() = "unsupported_method"; + method.scalar()() = "unsupported_method"; op.input_tensors.assign({nullptr, &method}); INFER_ERROR("unsupported_method", op, "?;?"); } diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index 33bed217003..087ff2ee847 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -318,7 +318,7 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { string target_device; OP_REQUIRES_OK_ASYNC( ctx, - DeviceNameUtils::CanonicalizeDeviceName(target->scalar()(), + DeviceNameUtils::CanonicalizeDeviceName(target->scalar()(), source_device, &target_device), done); diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc index 920c14b36ac..d7d15d5f14b 100644 --- a/tensorflow/core/kernels/functional_ops.cc +++ b/tensorflow/core/kernels/functional_ops.cc @@ -82,7 +82,7 @@ Status ToBool(gtl::ArraySlice t, bool* v) { *v = t[0].scalar()(); break; case DT_STRING: - *v = !t[0].scalar()().empty(); + *v = !t[0].scalar()().empty(); break; default: return errors::InvalidArgument(DataTypeString(t[0].dtype()), diff --git a/tensorflow/core/kernels/fuzzing/example_proto_fast_parsing_fuzz.cc b/tensorflow/core/kernels/fuzzing/example_proto_fast_parsing_fuzz.cc index f72dfb39b31..35cd0fba1a6 100644 --- a/tensorflow/core/kernels/fuzzing/example_proto_fast_parsing_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/example_proto_fast_parsing_fuzz.cc @@ -51,7 +51,7 @@ class FuzzExampleProtoFastParsing : public FuzzSession { void FuzzImpl(const uint8_t* data, size_t size) final { // TODO(dga): Test the batch case also. Tensor input_tensor(tensorflow::DT_STRING, TensorShape({})); - input_tensor.scalar()() = + input_tensor.scalar()() = string(reinterpret_cast(data), size); RunInputs({{"input", input_tensor}}); } diff --git a/tensorflow/core/kernels/fuzzing/fuzz_session.h b/tensorflow/core/kernels/fuzzing/fuzz_session.h index 4b036b181de..dc0435cdc32 100644 --- a/tensorflow/core/kernels/fuzzing/fuzz_session.h +++ b/tensorflow/core/kernels/fuzzing/fuzz_session.h @@ -145,7 +145,7 @@ class FuzzSession { class FuzzStringInputOp : public FuzzSession { void FuzzImpl(const uint8_t* data, size_t size) final { Tensor input_tensor(tensorflow::DT_STRING, TensorShape({})); - input_tensor.scalar()() = + input_tensor.scalar()() = string(reinterpret_cast(data), size); RunInputs({{"input", input_tensor}}); } diff --git a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc index 0ce4206fc3c..a71f2902559 100644 --- a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc @@ -61,7 +61,7 @@ class FuzzParseTensor : public FuzzSession { // Now we can do the actual fuzz implementation Tensor input_tensor(tensorflow::DT_STRING, TensorShape({})); - input_tensor.scalar()() = as_string; + input_tensor.scalar()() = as_string; RunInputs({{"input", input_tensor}}); } }; diff --git a/tensorflow/core/kernels/fuzzing/string_split_fuzz.cc b/tensorflow/core/kernels/fuzzing/string_split_fuzz.cc index b3b637bac72..d4e64181ab2 100644 --- a/tensorflow/core/kernels/fuzzing/string_split_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/string_split_fuzz.cc @@ -42,9 +42,9 @@ class FuzzStringSplit : public FuzzSession { if (delim_len > size) { delim_len = size - 1; } - delimiter_tensor.scalar()() = + delimiter_tensor.scalar()() = string(reinterpret_cast(data), delim_len); - input_tensor.scalar()() = string( + input_tensor.scalar()() = string( reinterpret_cast(data + delim_len), size - delim_len); RunInputs({{"input", input_tensor}, {"delimiter", delimiter_tensor}}); diff --git a/tensorflow/core/kernels/fuzzing/string_split_v2_fuzz.cc b/tensorflow/core/kernels/fuzzing/string_split_v2_fuzz.cc index f7e3da80437..367759d374e 100644 --- a/tensorflow/core/kernels/fuzzing/string_split_v2_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/string_split_v2_fuzz.cc @@ -46,10 +46,10 @@ class FuzzStringSplitV2 : public FuzzSession { if (sep_len > size) { sep_len = size - 1; } - separator_tensor.scalar()() = + separator_tensor.scalar()() = string(reinterpret_cast(data), sep_len); - input_tensor.scalar()() = string( - reinterpret_cast(data + sep_len), size - sep_len); + input_tensor.scalar()() = + string(reinterpret_cast(data + sep_len), size - sep_len); RunInputs({{"input", input_tensor}, {"separator", separator_tensor}}); } diff --git a/tensorflow/core/kernels/generate_vocab_remapping_op.cc b/tensorflow/core/kernels/generate_vocab_remapping_op.cc index 2b97677e385..03d9191423d 100644 --- a/tensorflow/core/kernels/generate_vocab_remapping_op.cc +++ b/tensorflow/core/kernels/generate_vocab_remapping_op.cc @@ -57,7 +57,7 @@ class GenerateVocabRemappingOp : public OpKernel { // Build a new ID->token lookup table. const string& new_vocab_filename = - new_vocab_file_tensor->scalar()(); + new_vocab_file_tensor->scalar()(); OP_REQUIRES(context, !new_vocab_filename.empty(), errors::InvalidArgument("new vocab filename cannot be empty.")); lookup::HashTable* new_vocab_table = @@ -88,7 +88,7 @@ class GenerateVocabRemappingOp : public OpKernel { old_vocab_file_tensor->shape().DebugString())); // Build a token->old ID lookup table. const string& old_vocab_filename = - old_vocab_file_tensor->scalar()(); + old_vocab_file_tensor->scalar()(); OP_REQUIRES(context, !old_vocab_filename.empty(), errors::InvalidArgument("new vocab filename cannot be empty.")); lookup::HashTable* old_vocab_table = @@ -118,7 +118,7 @@ class GenerateVocabRemappingOp : public OpKernel { OP_REQUIRES_OK( context, context->allocate_temp( DT_STRING, TensorShape({num_new_vocab_}), &default_token)); - auto default_token_vec = default_token.vec(); + auto default_token_vec = default_token.vec(); default_token_vec.setConstant("" /* NOT_FOUND_TOKEN */); Tensor default_id; diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc index c0d39d9d46d..a6f026150ea 100644 --- a/tensorflow/core/kernels/inplace_ops.cc +++ b/tensorflow/core/kernels/inplace_ops.cc @@ -319,8 +319,8 @@ void DoInplaceOp(const CPUDevice& d, InplaceOpType op, const Tensor& i, void DoInplaceStringUpdateOp(const CPUDevice& d, const Tensor& i, const Tensor& v, Tensor* y) { auto Ti = i.flat(); - auto Tv = v.flat_outer_dims(); - auto Ty = y->flat_outer_dims(); + auto Tv = v.flat_outer_dims(); + auto Ty = y->flat_outer_dims(); auto nrows = Ty.dimension(0); for (int64 j = 0; j < Ti.size(); ++j) { auto r = (Ti(j) % nrows + nrows) % nrows; // Guard index range. diff --git a/tensorflow/core/kernels/load_and_remap_matrix_op.cc b/tensorflow/core/kernels/load_and_remap_matrix_op.cc index 9d5a4b2f035..3b086517178 100644 --- a/tensorflow/core/kernels/load_and_remap_matrix_op.cc +++ b/tensorflow/core/kernels/load_and_remap_matrix_op.cc @@ -123,12 +123,12 @@ class LoadAndRemapMatrixOp : public OpKernel { // Processes the checkpoint source and the provided Tensor name. const Tensor* ckpt_path_t; OP_REQUIRES_OK(context, context->input("ckpt_path", &ckpt_path_t)); - const string ckpt_path = *(ckpt_path_t->scalar().data()); + const string ckpt_path = *(ckpt_path_t->scalar().data()); const Tensor* old_tensor_name_t; OP_REQUIRES_OK(context, context->input("old_tensor_name", &old_tensor_name_t)); const string old_tensor_name = - *(old_tensor_name_t->scalar().data()); + *(old_tensor_name_t->scalar().data()); LOG(INFO) << "Processing checkpoint : " << ckpt_path; BundleReader reader(context->env(), ckpt_path); diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc index f93d3246af4..e4d04c4245c 100644 --- a/tensorflow/core/kernels/logging_ops.cc +++ b/tensorflow/core/kernels/logging_ops.cc @@ -143,7 +143,7 @@ class PrintV2Op : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor* input_; OP_REQUIRES_OK(ctx, ctx->input("input", &input_)); - const string& msg = input_->scalar()(); + const string& msg = input_->scalar()(); string ended_msg = strings::StrCat(msg, end_); diff --git a/tensorflow/core/kernels/lookup_table_init_op.cc b/tensorflow/core/kernels/lookup_table_init_op.cc index 6e77e1ee012..83721b2cea4 100644 --- a/tensorflow/core/kernels/lookup_table_init_op.cc +++ b/tensorflow/core/kernels/lookup_table_init_op.cc @@ -130,7 +130,7 @@ class InitializeTableFromTextFileOp : public OpKernel { errors::InvalidArgument("filename should be a single string, but got ", vocab_filename_tensor.shape().DebugString())); - string vocab_filename = vocab_filename_tensor.scalar()(); + string vocab_filename = vocab_filename_tensor.scalar()(); OP_REQUIRES(ctx, !vocab_filename.empty(), errors::InvalidArgument("filename cannot be empty.")); diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h index 28a3d94e579..28d63cbf797 100644 --- a/tensorflow/core/kernels/lookup_table_op.h +++ b/tensorflow/core/kernels/lookup_table_op.h @@ -92,7 +92,7 @@ class LookupTableOp : public OpKernel { cinfo_.name()); } else { if (!table_handle_set_) { - auto h = table_handle_.AccessTensor(ctx)->template flat(); + auto h = table_handle_.AccessTensor(ctx)->template flat(); h(0) = cinfo_.container(); h(1) = cinfo_.name(); } diff --git a/tensorflow/core/kernels/lookup_util.cc b/tensorflow/core/kernels/lookup_util.cc index c3b80f04ed2..1fe7988aa67 100644 --- a/tensorflow/core/kernels/lookup_util.cc +++ b/tensorflow/core/kernels/lookup_util.cc @@ -238,7 +238,7 @@ class TextFileLineIterator tensor->flat()(0) = value; } break; case DT_STRING: - tensor->flat()(0) = token; + tensor->flat()(0) = token; break; default: valid_ = false; @@ -264,7 +264,7 @@ Status GetTableHandle(const string& input_name, OpKernelContext* ctx, "Lookup table handle must be scalar, but had shape: ", tensor.shape().DebugString()); } - auto h = tensor.flat(); + auto h = tensor.flat(); *container = h(0); *table_handle = h(1); } diff --git a/tensorflow/core/kernels/matching_files_op.cc b/tensorflow/core/kernels/matching_files_op.cc index 7912ca1563c..0ba718c88ec 100644 --- a/tensorflow/core/kernels/matching_files_op.cc +++ b/tensorflow/core/kernels/matching_files_op.cc @@ -40,7 +40,7 @@ class MatchingFilesOp : public OpKernel { errors::InvalidArgument( "Input patterns tensor must be scalar or vector, but had shape: ", patterns_t->shape().DebugString())); - const auto patterns = patterns_t->flat(); + const auto patterns = patterns_t->flat(); int num_patterns = patterns.size(); int num_files = 0; std::vector> all_fnames(num_patterns); @@ -53,7 +53,7 @@ class MatchingFilesOp : public OpKernel { OP_REQUIRES_OK( context, context->allocate_output("filenames", TensorShape({num_files}), &output_t)); - auto output = output_t->vec(); + auto output = output_t->vec(); int index = 0; for (int i = 0; i < num_patterns; ++i) { for (int j = 0; j < all_fnames[i].size(); j++) { diff --git a/tensorflow/core/kernels/parse_tensor_op.cc b/tensorflow/core/kernels/parse_tensor_op.cc index 8e175fe8d4b..d273f671c6c 100644 --- a/tensorflow/core/kernels/parse_tensor_op.cc +++ b/tensorflow/core/kernels/parse_tensor_op.cc @@ -39,7 +39,7 @@ class ParseTensorOp : public OpKernel { "Expected `serialized` to be a scalar, got shape: ", serialized.shape().DebugString())); - auto serialized_t = serialized.scalar(); + auto serialized_t = serialized.scalar(); TensorProto proto; OP_REQUIRES(ctx, ParseProtoUnlimited(&proto, serialized_t()), diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc index 6ed5bb0c752..67e8c943a65 100644 --- a/tensorflow/core/kernels/queue_ops.cc +++ b/tensorflow/core/kernels/queue_ops.cc @@ -84,8 +84,8 @@ class FakeQueueOp : public OpKernel { void Compute(OpKernelContext* context) override { const ResourceHandle& ref = context->input(0).flat()(0); - handle_.AccessTensor(context)->flat()(0) = ref.container(); - handle_.AccessTensor(context)->flat()(1) = ref.name(); + handle_.AccessTensor(context)->flat()(0) = ref.container(); + handle_.AccessTensor(context)->flat()(1) = ref.name(); context->set_output_ref(0, &mu_, handle_.AccessTensor(context)); } diff --git a/tensorflow/core/kernels/reader_ops.cc b/tensorflow/core/kernels/reader_ops.cc index abd16de6a1c..d93197c5b04 100644 --- a/tensorflow/core/kernels/reader_ops.cc +++ b/tensorflow/core/kernels/reader_ops.cc @@ -139,8 +139,8 @@ class ReaderReadUpToOp : public ReaderVerbAsyncOpKernel { context->allocate_output( "values", TensorShape({num_actually_read}), &values)); - auto keys_t = keys->vec(); - auto values_t = values->vec(); + auto keys_t = keys->vec(); + auto values_t = values->vec(); for (int i = 0; i < num_actually_read; ++i) { keys_t(i) = std::move(keys_vec[i]); values_t(i) = std::move(values_vec[i]); @@ -221,7 +221,7 @@ class ReaderRestoreStateOp : public ReaderVerbSyncOpKernel { context, TensorShapeUtils::IsScalar(tensor->shape()), errors::InvalidArgument("Reader state must be scalar, but had shape: ", tensor->shape().DebugString())); - OP_REQUIRES_OK(context, reader->RestoreState(tensor->scalar()())); + OP_REQUIRES_OK(context, reader->RestoreState(tensor->scalar()())); } }; diff --git a/tensorflow/core/kernels/reduce_join_op.cc b/tensorflow/core/kernels/reduce_join_op.cc index 7a81dfd0369..562281ea308 100644 --- a/tensorflow/core/kernels/reduce_join_op.cc +++ b/tensorflow/core/kernels/reduce_join_op.cc @@ -122,7 +122,7 @@ class ReduceJoinOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); - const auto input_flat = input.flat(); + const auto input_flat = input.flat(); const TensorShape& input_shape = input.shape(); const int32 input_dims = input_shape.dims(); @@ -156,7 +156,7 @@ class ReduceJoinOp : public OpKernel { GetOutputShape(index_is_reduced, input_shape, keep_dims_); OP_REQUIRES_OK(context, context->allocate_output("output", output_shape, &output_tensor)); - auto output_flat = output_tensor->flat(); + auto output_flat = output_tensor->flat(); const int64 reduction_iter_size = GetReductionIterSize(reduced_indices, input_shape); diff --git a/tensorflow/core/kernels/regex_full_match_op.cc b/tensorflow/core/kernels/regex_full_match_op.cc index 7edaaad8f78..04da969df12 100644 --- a/tensorflow/core/kernels/regex_full_match_op.cc +++ b/tensorflow/core/kernels/regex_full_match_op.cc @@ -31,14 +31,14 @@ class RegexFullMatchOp : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor* input_tensor; OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); - const auto& input_flat = input_tensor->flat(); + const auto& input_flat = input_tensor->flat(); const Tensor* pattern_tensor; OP_REQUIRES_OK(ctx, ctx->input("pattern", &pattern_tensor)); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(pattern_tensor->shape()), errors::InvalidArgument("Pattern must be scalar, but received ", pattern_tensor->shape().DebugString())); - const string pattern = pattern_tensor->flat()(0); + const string pattern = pattern_tensor->flat()(0); const RE2 match(pattern); OP_REQUIRES(ctx, match.ok(), errors::InvalidArgument("Invalid pattern: ", pattern, @@ -71,7 +71,7 @@ class StaticRegexFullMatchOp : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor* input_tensor; OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); - const auto& input_flat = input_tensor->flat(); + const auto& input_flat = input_tensor->flat(); Tensor* output_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(), diff --git a/tensorflow/core/kernels/regex_replace_op.cc b/tensorflow/core/kernels/regex_replace_op.cc index a1b948891d6..76c57350c52 100644 --- a/tensorflow/core/kernels/regex_replace_op.cc +++ b/tensorflow/core/kernels/regex_replace_op.cc @@ -44,9 +44,9 @@ Status InternalCompute(const RE2& match, const string& rewrite, } else { TF_RETURN_IF_ERROR( ctx->allocate_output("output", input_tensor->shape(), &output_tensor)); - output_tensor->flat() = input_tensor->flat(); + output_tensor->flat() = input_tensor->flat(); } - auto output_flat = output_tensor->flat(); + auto output_flat = output_tensor->flat(); for (size_t i = 0; i < output_flat.size(); ++i) { if (replace_global) { RE2::GlobalReplace(&output_flat(i), match, rewrite); @@ -70,7 +70,7 @@ class RegexReplaceOp : public OpKernel { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(pattern_tensor->shape()), errors::InvalidArgument("Pattern must be scalar, but received ", pattern_tensor->shape().DebugString())); - const string pattern = pattern_tensor->flat()(0); + const string pattern = pattern_tensor->flat()(0); const RE2 match(pattern); OP_REQUIRES(ctx, match.ok(), errors::InvalidArgument("Invalid pattern: ", pattern, @@ -81,7 +81,7 @@ class RegexReplaceOp : public OpKernel { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rewrite_tensor->shape()), errors::InvalidArgument("Rewrite must be scalar, but received ", rewrite_tensor->shape().DebugString())); - const string rewrite = rewrite_tensor->flat()(0); + const string rewrite = rewrite_tensor->flat()(0); OP_REQUIRES_OK(ctx, InternalCompute(match, rewrite, replace_global_, ctx)); } diff --git a/tensorflow/core/kernels/regex_replace_op_test.cc b/tensorflow/core/kernels/regex_replace_op_test.cc index 9691d4a89f5..bfc45e8bc07 100644 --- a/tensorflow/core/kernels/regex_replace_op_test.cc +++ b/tensorflow/core/kernels/regex_replace_op_test.cc @@ -60,7 +60,7 @@ const char kRewrite[] = " "; Tensor GetTestTensor(int batch) { const int sz = TF_ARRAYSIZE(lines); Tensor t(DT_STRING, {batch}); - auto s = t.flat(); + auto s = t.flat(); for (int i = 0; i < batch; ++i) { s(i) = lines[i % sz]; } diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc index 26f107f9403..5e01f4d2d33 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc @@ -1356,7 +1356,7 @@ RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments( dst_ptr = tensor->flat().data(); break; case DT_STRING: - dst_ptr = tensor->flat().data(); + dst_ptr = tensor->flat().data(); break; case DT_INT64: dst_ptr = tensor->flat().data(); diff --git a/tensorflow/core/kernels/restore_op_test.cc b/tensorflow/core/kernels/restore_op_test.cc index b6f15a9dc25..1e6ca10d4a1 100644 --- a/tensorflow/core/kernels/restore_op_test.cc +++ b/tensorflow/core/kernels/restore_op_test.cc @@ -94,7 +94,7 @@ TEST_F(RestoreOpTest, RestoreSimple) { // Input #0 is the file name Tensor input_0(DT_STRING, TensorShape({})); - input_0.scalar()() = filename; + input_0.scalar()() = filename; inputs.push_back({nullptr, &input_0}); // Input #1 is the tensor names @@ -203,7 +203,7 @@ TEST_F(RestoreOpTest, RestoreSimple) { // The 1-d integer tensor { MakeRestoreOp(DT_INT32); - (*mutable_input(1).tensor).scalar()() = tensor_names[1]; + (*mutable_input(1).tensor).scalar()() = tensor_names[1]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({10}); @@ -215,7 +215,7 @@ TEST_F(RestoreOpTest, RestoreSimple) { // The 2-d float tensor { MakeRestoreOp(DT_FLOAT); - (*mutable_input(1).tensor).scalar()() = tensor_names[2]; + (*mutable_input(1).tensor).scalar()() = tensor_names[2]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({2, 4}); @@ -227,7 +227,7 @@ TEST_F(RestoreOpTest, RestoreSimple) { // The 2-d double tensor { MakeRestoreOp(DT_DOUBLE); - (*mutable_input(1).tensor).scalar()() = tensor_names[3]; + (*mutable_input(1).tensor).scalar()() = tensor_names[3]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({2, 4}); @@ -239,7 +239,7 @@ TEST_F(RestoreOpTest, RestoreSimple) { // The 2-d qint8 tensor { MakeRestoreOp(DT_QINT8); - (*mutable_input(1).tensor).scalar()() = tensor_names[4]; + (*mutable_input(1).tensor).scalar()() = tensor_names[4]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({3, 2}); @@ -251,7 +251,7 @@ TEST_F(RestoreOpTest, RestoreSimple) { // The 2-d qint32 tensor { MakeRestoreOp(DT_QINT32); - (*mutable_input(1).tensor).scalar()() = tensor_names[5]; + (*mutable_input(1).tensor).scalar()() = tensor_names[5]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({2, 3}); @@ -264,7 +264,7 @@ TEST_F(RestoreOpTest, RestoreSimple) { // The 1-d uint8 tensor { MakeRestoreOp(DT_UINT8); - (*mutable_input(1).tensor).scalar()() = tensor_names[6]; + (*mutable_input(1).tensor).scalar()() = tensor_names[6]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({11}); @@ -276,7 +276,7 @@ TEST_F(RestoreOpTest, RestoreSimple) { // The 1-d int8 tensor { MakeRestoreOp(DT_INT8); - (*mutable_input(1).tensor).scalar()() = tensor_names[7]; + (*mutable_input(1).tensor).scalar()() = tensor_names[7]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({7}); @@ -288,7 +288,7 @@ TEST_F(RestoreOpTest, RestoreSimple) { // The 1-d int16 tensor { MakeRestoreOp(DT_INT16); - (*mutable_input(1).tensor).scalar()() = tensor_names[8]; + (*mutable_input(1).tensor).scalar()() = tensor_names[8]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({7}); @@ -300,7 +300,7 @@ TEST_F(RestoreOpTest, RestoreSimple) { // The 1-d int64 tensor { MakeRestoreOp(DT_INT64); - (*mutable_input(1).tensor).scalar()() = tensor_names[9]; + (*mutable_input(1).tensor).scalar()() = tensor_names[9]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({9}); @@ -312,18 +312,18 @@ TEST_F(RestoreOpTest, RestoreSimple) { // The 1-d string tensor { MakeRestoreOp(DT_STRING); - (*mutable_input(1).tensor).scalar()() = tensor_names[10]; + (*mutable_input(1).tensor).scalar()() = tensor_names[10]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({2}); EXPECT_TRUE(output->shape().IsSameSize(expected)); - EXPECT_EQ("no", output->flat()(0)); - EXPECT_EQ("yes", output->flat()(1)); + EXPECT_EQ("no", output->flat()(0)); + EXPECT_EQ("yes", output->flat()(1)); } // The 2-d complex64 tensor { MakeRestoreOp(DT_COMPLEX64); - (*mutable_input(1).tensor).scalar()() = tensor_names[11]; + (*mutable_input(1).tensor).scalar()() = tensor_names[11]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({2, 3}); @@ -335,7 +335,7 @@ TEST_F(RestoreOpTest, RestoreSimple) { // The 2-d half tensor { MakeRestoreOp(DT_HALF); - (*mutable_input(1).tensor).scalar()() = tensor_names[12]; + (*mutable_input(1).tensor).scalar()() = tensor_names[12]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({2, 4}); @@ -348,7 +348,7 @@ TEST_F(RestoreOpTest, RestoreSimple) { // The 2-d empty float tensor { MakeRestoreOp(DT_FLOAT); - (*mutable_input(1).tensor).scalar()() = tensor_names[13]; + (*mutable_input(1).tensor).scalar()() = tensor_names[13]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({2, 0}); @@ -398,12 +398,12 @@ TEST_F(RestoreSliceOpTest, RestoreInt) { // Input #0 is the file name Tensor input_0(DT_STRING, TensorShape({})); - input_0.scalar()() = filename; + input_0.scalar()() = filename; inputs.push_back({nullptr, &input_0}); // Input #1 is the tensor name Tensor input_1(DT_STRING, TensorShape({})); - input_1.scalar()() = tensor_name; + input_1.scalar()() = tensor_name; inputs.push_back({nullptr, &input_1}); // Input #2 is a 4x16 integer tensor. diff --git a/tensorflow/core/kernels/restore_v2_op_test.cc b/tensorflow/core/kernels/restore_v2_op_test.cc index 36631570c7b..22eb99d2153 100644 --- a/tensorflow/core/kernels/restore_v2_op_test.cc +++ b/tensorflow/core/kernels/restore_v2_op_test.cc @@ -105,7 +105,7 @@ class RestoreV2OpTest : public OpsTestBase { // Input #0 is the file name Tensor input_0(DT_STRING, TensorShape({})); - input_0.scalar()() = filename; + input_0.scalar()() = filename; inputs.push_back({nullptr, &input_0}); // Input #1 is the tensor names @@ -213,7 +213,7 @@ class RestoreV2OpTest : public OpsTestBase { // The 1-d integer tensor { MakeRestoreOp(DT_INT32); - (*mutable_input(1).tensor).flat()(0) = tensor_names[1]; + (*mutable_input(1).tensor).flat()(0) = tensor_names[1]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({10}); @@ -225,7 +225,7 @@ class RestoreV2OpTest : public OpsTestBase { // The 2-d float tensor { MakeRestoreOp(DT_FLOAT); - (*mutable_input(1).tensor).flat()(0) = tensor_names[2]; + (*mutable_input(1).tensor).flat()(0) = tensor_names[2]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({2, 4}); @@ -237,7 +237,7 @@ class RestoreV2OpTest : public OpsTestBase { // The 2-d double tensor { MakeRestoreOp(DT_DOUBLE); - (*mutable_input(1).tensor).flat()(0) = tensor_names[3]; + (*mutable_input(1).tensor).flat()(0) = tensor_names[3]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({2, 4}); @@ -249,7 +249,7 @@ class RestoreV2OpTest : public OpsTestBase { // The 2-d qint8 tensor { MakeRestoreOp(DT_QINT8); - (*mutable_input(1).tensor).flat()(0) = tensor_names[4]; + (*mutable_input(1).tensor).flat()(0) = tensor_names[4]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({3, 2}); @@ -261,7 +261,7 @@ class RestoreV2OpTest : public OpsTestBase { // The 2-d qint32 tensor { MakeRestoreOp(DT_QINT32); - (*mutable_input(1).tensor).flat()(0) = tensor_names[5]; + (*mutable_input(1).tensor).flat()(0) = tensor_names[5]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({2, 3}); @@ -274,7 +274,7 @@ class RestoreV2OpTest : public OpsTestBase { // The 1-d uint8 tensor { MakeRestoreOp(DT_UINT8); - (*mutable_input(1).tensor).flat()(0) = tensor_names[6]; + (*mutable_input(1).tensor).flat()(0) = tensor_names[6]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({11}); @@ -286,7 +286,7 @@ class RestoreV2OpTest : public OpsTestBase { // The 1-d int8 tensor { MakeRestoreOp(DT_INT8); - (*mutable_input(1).tensor).flat()(0) = tensor_names[7]; + (*mutable_input(1).tensor).flat()(0) = tensor_names[7]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({7}); @@ -298,7 +298,7 @@ class RestoreV2OpTest : public OpsTestBase { // The 1-d int16 tensor { MakeRestoreOp(DT_INT16); - (*mutable_input(1).tensor).flat()(0) = tensor_names[8]; + (*mutable_input(1).tensor).flat()(0) = tensor_names[8]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({7}); @@ -310,7 +310,7 @@ class RestoreV2OpTest : public OpsTestBase { // The 1-d int64 tensor { MakeRestoreOp(DT_INT64); - (*mutable_input(1).tensor).flat()(0) = tensor_names[9]; + (*mutable_input(1).tensor).flat()(0) = tensor_names[9]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({9}); @@ -322,7 +322,7 @@ class RestoreV2OpTest : public OpsTestBase { // The 2-d complex64 tensor { MakeRestoreOp(DT_COMPLEX64); - (*mutable_input(1).tensor).flat()(0) = tensor_names[10]; + (*mutable_input(1).tensor).flat()(0) = tensor_names[10]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({2, 3}); @@ -334,7 +334,7 @@ class RestoreV2OpTest : public OpsTestBase { // The 2-d half tensor { MakeRestoreOp(DT_HALF); - (*mutable_input(1).tensor).flat()(0) = tensor_names[11]; + (*mutable_input(1).tensor).flat()(0) = tensor_names[11]; TF_ASSERT_OK(RunOpKernel()); Tensor* output = GetOutput(0); TensorShape expected({2, 4}); diff --git a/tensorflow/core/kernels/save_op.cc b/tensorflow/core/kernels/save_op.cc index f87e0fa0e9c..f53976cae28 100644 --- a/tensorflow/core/kernels/save_op.cc +++ b/tensorflow/core/kernels/save_op.cc @@ -62,8 +62,8 @@ class ShardedFilenameOp : public OpKernel { } Tensor* out = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out)); - out->scalar()() = strings::Printf( - "%s-%05d-of-%05d", ctx->input(0).scalar()().c_str(), + out->scalar()() = strings::Printf( + "%s-%05d-of-%05d", ctx->input(0).scalar()().c_str(), ctx->input(1).scalar()(), ctx->input(2).scalar()()); } }; @@ -85,8 +85,8 @@ class ShardedFilespecOp : public OpKernel { } Tensor* out = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out)); - out->scalar()() = strings::Printf( - "%s-\?\?\?\?\?-of-%05d", ctx->input(0).scalar()().c_str(), + out->scalar()() = strings::Printf( + "%s-\?\?\?\?\?-of-%05d", ctx->input(0).scalar()().c_str(), ctx->input(1).scalar()()); } }; diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc index faafed367d3..f0a286721e1 100644 --- a/tensorflow/core/kernels/save_restore_tensor.cc +++ b/tensorflow/core/kernels/save_restore_tensor.cc @@ -70,7 +70,7 @@ void SaveTensors( "shapes and slices but got ", tensor_shapes_and_slices_t.NumElements())); tensor_shapes_and_slices_ptr = - tensor_shapes_and_slices_t.flat().data(); + tensor_shapes_and_slices_t.flat().data(); } OP_REQUIRES(context, context->num_inputs() == N + kFixedInputs, errors::InvalidArgument("Expected totally ", N + kFixedInputs, @@ -79,13 +79,13 @@ void SaveTensors( N, " names, but received ", context->num_inputs(), " inputs")); - VLOG(1) << "About to save tensors to file " << filename_t.flat()(0) + VLOG(1) << "About to save tensors to file " << filename_t.flat()(0) << "..."; - checkpoint::TensorSliceWriter writer(filename_t.flat()(0), + checkpoint::TensorSliceWriter writer(filename_t.flat()(0), std::move(builder_func)); Status s; - auto tensor_names_flat = tensor_names_t.flat(); + auto tensor_names_flat = tensor_names_t.flat(); // Process tensors in sorted name order. This allows us to avoid seeking // during restoration in the common case where we are restoring a full @@ -153,10 +153,10 @@ void RestoreTensor(OpKernelContext* context, "Input 0 (file_pattern) must be a string scalar; got a tensor of ", size, "elements")); } - const string& file_pattern = file_pattern_t.flat()(0); + const string& file_pattern = file_pattern_t.flat()(0); const Tensor& tensor_name_t = context->input(1); - const string& tensor_name = tensor_name_t.flat()(restore_index); + const string& tensor_name = tensor_name_t.flat()(restore_index); // If we cannot find a cached reader we will allocate our own. std::unique_ptr allocated_reader; @@ -192,7 +192,7 @@ void RestoreTensor(OpKernelContext* context, TensorShape output_shape(saved_shape); TensorSlice slice_to_load(saved_shape.dims()); if (restore_slice) { - const string& shape_spec = context->input(2).flat()(restore_index); + const string& shape_spec = context->input(2).flat()(restore_index); if (!shape_spec.empty()) { TensorShape parsed_shape; OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice( @@ -318,10 +318,10 @@ Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix, const Tensor& tensor_names, const Tensor& shape_and_slices, gtl::ArraySlice dtypes) { - const string& prefix_string = prefix.scalar()(); + const string& prefix_string = prefix.scalar()(); - const auto& tensor_names_flat = tensor_names.flat(); - const auto& shape_and_slices_flat = shape_and_slices.flat(); + const auto& tensor_names_flat = tensor_names.flat(); + const auto& shape_and_slices_flat = shape_and_slices.flat(); // Sort lookup keys to improve locality when reading multiple tensors. std::vector sorted_name_idx(tensor_names_flat.size()); diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc index ed1195c0535..512fd9bebfe 100644 --- a/tensorflow/core/kernels/save_restore_v2_ops.cc +++ b/tensorflow/core/kernels/save_restore_v2_ops.cc @@ -101,9 +101,9 @@ class SaveV2 : public OpKernel { const int kFixedInputs = 3; // Prefix, tensor names, shape_and_slices. const int num_tensors = static_cast(tensor_names.NumElements()); - const string& prefix_string = prefix.scalar()(); - const auto& tensor_names_flat = tensor_names.flat(); - const auto& shape_and_slices_flat = shape_and_slices.flat(); + const string& prefix_string = prefix.scalar()(); + const auto& tensor_names_flat = tensor_names.flat(); + const auto& shape_and_slices_flat = shape_and_slices.flat(); BundleWriter writer(Env::Default(), prefix_string); OP_REQUIRES_OK(context, writer.status()); @@ -157,7 +157,7 @@ class RestoreV2 : public OpKernel { ValidateInputs(false /* not save op */, context, prefix, tensor_names, shape_and_slices); - const string& prefix_string = prefix.scalar()(); + const string& prefix_string = prefix.scalar()(); // Intention: we plan to use the RestoreV2 op as a backward-compatible // reader as we upgrade to the V2 format. This allows transparent upgrade. @@ -215,7 +215,7 @@ class MergeV2Checkpoints : public OpKernel { const gtl::ArraySlice input_prefixes = gtl::ArraySlice(checkpoint_prefixes.flat()); Env* env = Env::Default(); - const string& merged_prefix = destination_prefix.scalar()(); + const string& merged_prefix = destination_prefix.scalar()(); OP_REQUIRES_OK( context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix)); diff --git a/tensorflow/core/kernels/sdca_ops.cc b/tensorflow/core/kernels/sdca_ops.cc index d0e0b15da78..4fdb7d1e257 100644 --- a/tensorflow/core/kernels/sdca_ops.cc +++ b/tensorflow/core/kernels/sdca_ops.cc @@ -312,7 +312,7 @@ class SdcaFprint : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output( 0, TensorShape({num_elements, 2}), &out)); - const auto in_values = input.flat(); + const auto in_values = input.flat(); auto out_values = out->matrix(); for (int64 i = 0; i < num_elements; ++i) { diff --git a/tensorflow/core/kernels/session_ops.cc b/tensorflow/core/kernels/session_ops.cc index f2dd2812b53..d83a714452f 100644 --- a/tensorflow/core/kernels/session_ops.cc +++ b/tensorflow/core/kernels/session_ops.cc @@ -57,7 +57,7 @@ class GetSessionHandleOp : public OpKernel { handle->scalar()() = resource_handle; } else { // Legacy behavior in V1. - handle->flat().setConstant(tk.GetHandle(name())); + handle->flat().setConstant(tk.GetHandle(name())); } } @@ -110,7 +110,7 @@ class GetSessionTensorOp : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor& handle = ctx->input(0); - const string& name = handle.scalar()(); + const string& name = handle.scalar()(); Tensor val; OP_REQUIRES_OK(ctx, ctx->session_state()->GetTensor(name, &val)); ctx->set_output(0, val); @@ -153,7 +153,7 @@ class DeleteSessionTensorOp : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor& handle = ctx->input(0); - const string& name = handle.scalar()(); + const string& name = handle.scalar()(); OP_REQUIRES_OK(ctx, ctx->session_state()->DeleteTensor(name)); } diff --git a/tensorflow/core/kernels/sparse_cross_op.cc b/tensorflow/core/kernels/sparse_cross_op.cc index 8e92c9e5517..a6a4060b26e 100644 --- a/tensorflow/core/kernels/sparse_cross_op.cc +++ b/tensorflow/core/kernels/sparse_cross_op.cc @@ -78,7 +78,7 @@ template <> int64 SparseTensorColumn::Feature(int64 batch, int64 n) const { const int64 start = feature_start_indices_[batch]; if (DT_STRING == values_.dtype()) - return Fingerprint64(values_.vec().data()[start + n]); + return Fingerprint64(values_.vec().data()[start + n]); return values_.vec().data()[start + n]; } @@ -87,7 +87,7 @@ template <> string SparseTensorColumn::Feature(int64 batch, int64 n) const { const int64 start = feature_start_indices_[batch]; if (DT_STRING == values_.dtype()) - return values_.vec().data()[start + n]; + return values_.vec().data()[start + n]; return std::to_string(values_.vec().data()[start + n]); } @@ -95,7 +95,7 @@ template <> StringPiece SparseTensorColumn::Feature(int64 batch, int64 n) const { const int64 start = feature_start_indices_[batch]; - return values_.vec().data()[start + n]; + return values_.vec().data()[start + n]; } // A column that is backed by a dense tensor. @@ -118,21 +118,21 @@ class DenseTensorColumn : public ColumnInterface { template <> int64 DenseTensorColumn::Feature(int64 batch, int64 n) const { if (DT_STRING == tensor_.dtype()) - return Fingerprint64(tensor_.matrix()(batch, n)); + return Fingerprint64(tensor_.matrix()(batch, n)); return tensor_.matrix()(batch, n); } // Internal type is string or StringPiece when using StringCrosser. template <> string DenseTensorColumn::Feature(int64 batch, int64 n) const { - if (DT_STRING == tensor_.dtype()) return tensor_.matrix()(batch, n); + if (DT_STRING == tensor_.dtype()) return tensor_.matrix()(batch, n); return std::to_string(tensor_.matrix()(batch, n)); } template <> StringPiece DenseTensorColumn::Feature(int64 batch, int64 n) const { - return tensor_.matrix()(batch, n); + return tensor_.matrix()(batch, n); } // Updates Output tensors with sparse crosses. diff --git a/tensorflow/core/kernels/stack.cc b/tensorflow/core/kernels/stack.cc index d6a79049277..af8f760d47f 100644 --- a/tensorflow/core/kernels/stack.cc +++ b/tensorflow/core/kernels/stack.cc @@ -134,8 +134,8 @@ Status GetStack(OpKernelContext* ctx, Stack** stack) { "Stack handle must have two elements, but had shape: ", Tstack_handle.shape().DebugString()); } - const string& container = Tstack_handle.flat()(0); - const string& stack_name = Tstack_handle.flat()(1); + const string& container = Tstack_handle.flat()(0); + const string& stack_name = Tstack_handle.flat()(1); string key = strings::StrCat(container, stack_name); ResourceMgr* rm = ctx->resource_manager(); if (rm == nullptr) { @@ -196,7 +196,7 @@ void StackOp::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_STRING, tensorflow::TensorShape({2}), &stack->handle_, alloc_attr)); - auto handle = stack->handle_.flat(); + auto handle = stack->handle_.flat(); handle(0) = kContainer; handle(1) = std::move(stack_name); ctx->set_output_ref(0, stack->mu(), &stack->handle_); diff --git a/tensorflow/core/kernels/string_format_op.cc b/tensorflow/core/kernels/string_format_op.cc index e4a1887f8d3..e42854cedd3 100644 --- a/tensorflow/core/kernels/string_format_op.cc +++ b/tensorflow/core/kernels/string_format_op.cc @@ -50,7 +50,7 @@ class StringFormatOp : public OpKernel { strings::StrAppend(&msg, split_template_[i + 1].c_str()); } - formatted_string->scalar()() = msg; + formatted_string->scalar()() = msg; } private: diff --git a/tensorflow/core/kernels/string_join_op.cc b/tensorflow/core/kernels/string_join_op.cc index 4b9c19da691..5532f6d6fe9 100644 --- a/tensorflow/core/kernels/string_join_op.cc +++ b/tensorflow/core/kernels/string_join_op.cc @@ -42,7 +42,7 @@ class StringJoinOp : public OpKernel { std::vector::ConstFlat> inputs; for (const auto& input : input_list) { - inputs.push_back(input.flat()); + inputs.push_back(input.flat()); is_scalar.push_back(TensorShapeUtils::IsScalar(input.shape())); if (!TensorShapeUtils::IsScalar(input.shape())) { if (TensorShapeUtils::IsScalar(input_shape)) { @@ -60,7 +60,7 @@ class StringJoinOp : public OpKernel { Tensor* output_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output("output", input_shape, &output_tensor)); - auto output_flat = output_tensor->flat(); + auto output_flat = output_tensor->flat(); std::vector strings(input_list.size()); for (size_t i = 0; i < input_shape.num_elements(); ++i) { diff --git a/tensorflow/core/kernels/string_length_op.cc b/tensorflow/core/kernels/string_length_op.cc index 435a7abdcac..53a161353f0 100644 --- a/tensorflow/core/kernels/string_length_op.cc +++ b/tensorflow/core/kernels/string_length_op.cc @@ -34,7 +34,7 @@ class StringLengthOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &output)); - auto src = input.flat(); + auto src = input.flat(); auto dst = output->flat(); switch (unit_) { diff --git a/tensorflow/core/kernels/string_lower_op.cc b/tensorflow/core/kernels/string_lower_op.cc index e24eedcc3ae..07065d2777e 100644 --- a/tensorflow/core/kernels/string_lower_op.cc +++ b/tensorflow/core/kernels/string_lower_op.cc @@ -45,8 +45,8 @@ class StringLowerOp : public OpKernel { OP_REQUIRES_OK( ctx, ctx->allocate_output(0, input_tensor->shape(), &output_tensor)); - const auto input = input_tensor->flat(); - auto output = output_tensor->flat(); + const auto input = input_tensor->flat(); + auto output = output_tensor->flat(); if (encoding_.empty()) { for (int64 i = 0; i < input.size(); ++i) { diff --git a/tensorflow/core/kernels/string_ngrams_op.cc b/tensorflow/core/kernels/string_ngrams_op.cc index 37a7aa956d0..430d91bef88 100644 --- a/tensorflow/core/kernels/string_ngrams_op.cc +++ b/tensorflow/core/kernels/string_ngrams_op.cc @@ -54,14 +54,14 @@ class StringNGramsOp : public tensorflow::OpKernel { void Compute(tensorflow::OpKernelContext* context) override { const tensorflow::Tensor* data; OP_REQUIRES_OK(context, context->input("data", &data)); - const auto& input_data = data->flat().data(); + const auto& input_data = data->flat().data(); const tensorflow::Tensor* splits; OP_REQUIRES_OK(context, context->input("data_splits", &splits)); const auto& splits_vec = splits->flat(); // If there is no data or size, return an empty RT. - if (data->flat().size() == 0 || splits_vec.size() == 0) { + if (data->flat().size() == 0 || splits_vec.size() == 0) { tensorflow::Tensor* empty; OP_REQUIRES_OK(context, context->allocate_output(0, data->shape(), &empty)); @@ -93,7 +93,7 @@ class StringNGramsOp : public tensorflow::OpKernel { context, context->allocate_output( 0, TensorShape({ngrams_splits_data[num_batch_items]}), &ngrams)); - auto ngrams_data = ngrams->flat().data(); + auto ngrams_data = ngrams->flat().data(); for (int i = 0; i < num_batch_items; ++i) { auto data_start = &input_data[splits_vec(i)]; diff --git a/tensorflow/core/kernels/string_split_op.cc b/tensorflow/core/kernels/string_split_op.cc index 3884370a6c6..d6d27debf89 100644 --- a/tensorflow/core/kernels/string_split_op.cc +++ b/tensorflow/core/kernels/string_split_op.cc @@ -178,7 +178,7 @@ class StringSplitOp : public OpKernel { errors::InvalidArgument("input must be a vector, got shape: ", input_tensor->shape().DebugString())); - const auto input_vec = input_tensor->vec(); + const auto input_vec = input_tensor->vec(); const int64 batch_size = input_vec.dimension(0); const Tensor* delimiter_tensor; @@ -220,7 +220,7 @@ class StringSplitOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({2}), &sp_shape_t)); auto sp_indices = sp_indices_t->matrix(); - auto sp_tokens = sp_tokens_t->vec(); + auto sp_tokens = sp_tokens_t->vec(); auto sp_shape = sp_shape_t->vec(); sp_shape(0) = batch_size; sp_shape(1) = max_num_entries; @@ -253,7 +253,7 @@ class StringSplitV2Op : public OpKernel { errors::InvalidArgument("input must be a vector, got shape: ", input_tensor->shape().DebugString())); - const auto input_vec = input_tensor->vec(); + const auto input_vec = input_tensor->vec(); const int64 batch_size = input_vec.dimension(0); const Tensor* sep_tensor; @@ -261,7 +261,7 @@ class StringSplitV2Op : public OpKernel { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sep_tensor->shape()), errors::InvalidArgument("sep must be a scalar, got shape: ", sep_tensor->shape().DebugString())); - const auto sep_vec = sep_tensor->flat(); + const auto sep_vec = sep_tensor->flat(); StringPiece sep(sep_vec(0)); std::vector tokens; // Guess that we'll be unpacking a handful of tokens per example. @@ -290,7 +290,7 @@ class StringSplitV2Op : public OpKernel { OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({2}), &sp_shape_t)); auto sp_indices = sp_indices_t->matrix(); - auto sp_tokens = sp_tokens_t->vec(); + auto sp_tokens = sp_tokens_t->vec(); auto sp_shape = sp_shape_t->vec(); sp_shape(0) = batch_size; sp_shape(1) = max_num_entries; diff --git a/tensorflow/core/kernels/string_split_op_test.cc b/tensorflow/core/kernels/string_split_op_test.cc index 58ad61adc86..4494cf9dcf3 100644 --- a/tensorflow/core/kernels/string_split_op_test.cc +++ b/tensorflow/core/kernels/string_split_op_test.cc @@ -57,7 +57,7 @@ const char* lines[] = { Tensor GetTestTensor(int batch) { const int sz = TF_ARRAYSIZE(lines); Tensor t(DT_STRING, {batch}); - auto s = t.flat(); + auto s = t.flat(); for (int i = 0; i < batch; ++i) { s(i) = lines[i % sz]; } @@ -67,7 +67,7 @@ Tensor GetTestTensor(int batch) { Graph* SetupStringSplitGraph(const Tensor& input) { Graph* g = new Graph(OpRegistry::Global()); Tensor delim(DT_STRING, TensorShape({})); - delim.flat().setConstant(" "); + delim.flat().setConstant(" "); TF_CHECK_OK(NodeBuilder("string_split_op", "StringSplit") .Input(test::graph::Constant(g, input)) @@ -98,7 +98,7 @@ BENCHMARK(BM_StringSplit) Graph* SetupStringSplitV2Graph(const Tensor& input) { Graph* g = new Graph(OpRegistry::Global()); Tensor sep(DT_STRING, TensorShape({})); - sep.flat().setConstant(" "); + sep.flat().setConstant(" "); TF_CHECK_OK(NodeBuilder("string_split_op", "StringSplitV2") .Input(test::graph::Constant(g, input)) diff --git a/tensorflow/core/kernels/string_strip_op.cc b/tensorflow/core/kernels/string_strip_op.cc index 544dca96ba7..715ec271db5 100644 --- a/tensorflow/core/kernels/string_strip_op.cc +++ b/tensorflow/core/kernels/string_strip_op.cc @@ -37,8 +37,8 @@ class StringStripOp : public OpKernel { OP_REQUIRES_OK( ctx, ctx->allocate_output(0, input_tensor->shape(), &output_tensor)); - const auto input = input_tensor->flat(); - auto output = output_tensor->flat(); + const auto input = input_tensor->flat(); + auto output = output_tensor->flat(); for (int64 i = 0; i < input.size(); ++i) { StringPiece entry(input(i)); diff --git a/tensorflow/core/kernels/string_to_hash_bucket_op.cc b/tensorflow/core/kernels/string_to_hash_bucket_op.cc index 10fc6ee5434..1505ddbb9bc 100644 --- a/tensorflow/core/kernels/string_to_hash_bucket_op.cc +++ b/tensorflow/core/kernels/string_to_hash_bucket_op.cc @@ -33,7 +33,7 @@ class LegacyStringToHashBucketOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor* input_tensor; OP_REQUIRES_OK(context, context->input("string_tensor", &input_tensor)); - const auto& input_flat = input_tensor->flat(); + const auto& input_flat = input_tensor->flat(); Tensor* output_tensor = nullptr; OP_REQUIRES_OK(context, diff --git a/tensorflow/core/kernels/string_to_hash_bucket_op.h b/tensorflow/core/kernels/string_to_hash_bucket_op.h index 62ef35bbba4..8647695cf46 100644 --- a/tensorflow/core/kernels/string_to_hash_bucket_op.h +++ b/tensorflow/core/kernels/string_to_hash_bucket_op.h @@ -36,7 +36,7 @@ class StringToHashBucketOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor* input_tensor; OP_REQUIRES_OK(context, context->input("input", &input_tensor)); - const auto& input_flat = input_tensor->flat(); + const auto& input_flat = input_tensor->flat(); Tensor* output_tensor = nullptr; OP_REQUIRES_OK(context, @@ -78,7 +78,7 @@ class StringToKeyedHashBucketOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor* input_tensor; OP_REQUIRES_OK(context, context->input("input", &input_tensor)); - const auto& input_flat = input_tensor->flat(); + const auto& input_flat = input_tensor->flat(); Tensor* output_tensor = nullptr; OP_REQUIRES_OK(context, diff --git a/tensorflow/core/kernels/string_to_number_op.cc b/tensorflow/core/kernels/string_to_number_op.cc index 22742dd38e5..8340f35428b 100644 --- a/tensorflow/core/kernels/string_to_number_op.cc +++ b/tensorflow/core/kernels/string_to_number_op.cc @@ -40,7 +40,7 @@ class StringToNumberOp : public OpKernel { // underlying storage. const Tensor* input_tensor; OP_REQUIRES_OK(context, context->input("string_tensor", &input_tensor)); - const auto& input_flat = input_tensor->flat(); + const auto& input_flat = input_tensor->flat(); Tensor* output_tensor = nullptr; OP_REQUIRES_OK(context, diff --git a/tensorflow/core/kernels/string_upper_op.cc b/tensorflow/core/kernels/string_upper_op.cc index f2a1d33e7a6..d9f088a7b78 100644 --- a/tensorflow/core/kernels/string_upper_op.cc +++ b/tensorflow/core/kernels/string_upper_op.cc @@ -45,8 +45,8 @@ class StringUpperOp : public OpKernel { OP_REQUIRES_OK( ctx, ctx->allocate_output(0, input_tensor->shape(), &output_tensor)); - const auto input = input_tensor->flat(); - auto output = output_tensor->flat(); + const auto input = input_tensor->flat(); + auto output = output_tensor->flat(); if (encoding_.empty()) { for (int64 i = 0; i < input.size(); ++i) { StringPiece entry(input(i)); diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc index 77b16b9384d..458d67ccc5e 100644 --- a/tensorflow/core/kernels/substr_op.cc +++ b/tensorflow/core/kernels/substr_op.cc @@ -59,13 +59,13 @@ class SubstrOp : public OpKernel { // Do not need to do broadcasting // Reshape input - auto input = input_tensor.flat(); + auto input = input_tensor.flat(); // Allocate output Tensor* output_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output("output", input_tensor.shape(), &output_tensor)); - auto output = output_tensor->flat(); + auto output = output_tensor->flat(); if (is_scalar) { // Perform Op with scalar pos/len const T pos = @@ -141,8 +141,8 @@ class SubstrOp : public OpKernel { switch (ndims) { case 1: { // Reshape tensors according to BCast results - auto input = input_tensor.shaped(bcast.x_reshape()); - auto output = output_tensor->shaped(bcast.result_shape()); + auto input = input_tensor.shaped(bcast.x_reshape()); + auto output = output_tensor->shaped(bcast.result_shape()); auto pos_shaped = pos_tensor.shaped(bcast.y_reshape()); auto len_shaped = len_tensor.shaped(bcast.y_reshape()); @@ -204,8 +204,8 @@ class SubstrOp : public OpKernel { } case 2: { // Reshape tensors according to BCast results - auto input = input_tensor.shaped(bcast.x_reshape()); - auto output = output_tensor->shaped(bcast.result_shape()); + auto input = input_tensor.shaped(bcast.x_reshape()); + auto output = output_tensor->shaped(bcast.result_shape()); auto pos_shaped = pos_tensor.shaped(bcast.y_reshape()); auto len_shaped = len_tensor.shaped(bcast.y_reshape()); diff --git a/tensorflow/core/kernels/substr_op_test.cc b/tensorflow/core/kernels/substr_op_test.cc index ea6b1ed5006..3aebfe3a212 100644 --- a/tensorflow/core/kernels/substr_op_test.cc +++ b/tensorflow/core/kernels/substr_op_test.cc @@ -115,7 +115,7 @@ const char* const kUTF8Unit = "UTF8_CHAR"; Tensor GetTestTensor(int batch) { const int sz = TF_ARRAYSIZE(ascii_lines); Tensor t(DT_STRING, {batch}); - auto s = t.flat(); + auto s = t.flat(); for (int i = 0; i < batch; ++i) { s(i) = ascii_lines[i % sz]; } @@ -125,7 +125,7 @@ Tensor GetTestTensor(int batch) { Tensor GetTestUTF8Tensor(int batch) { const int sz = TF_ARRAYSIZE(unicode_lines); Tensor t(DT_STRING, {batch}); - auto s = t.flat(); + auto s = t.flat(); for (int i = 0; i < batch; ++i) { s(i) = unicode_lines[i % sz]; } diff --git a/tensorflow/core/kernels/summary_audio_op.cc b/tensorflow/core/kernels/summary_audio_op.cc index f5ddb9081d6..fbb1c2c6473 100644 --- a/tensorflow/core/kernels/summary_audio_op.cc +++ b/tensorflow/core/kernels/summary_audio_op.cc @@ -44,7 +44,7 @@ class SummaryAudioOp : public OpKernel { OP_REQUIRES(c, tensor.dims() >= 2 && tensor.dims() <= 3, errors::InvalidArgument("Tensor must be 3-D or 2-D, got: ", tensor.shape().DebugString())); - const string& base_tag = tag.scalar()(); + const string& base_tag = tag.scalar()(); float sample_rate = sample_rate_attr_; if (!has_sample_rate_attr_) { diff --git a/tensorflow/core/kernels/summary_audio_op_test.cc b/tensorflow/core/kernels/summary_audio_op_test.cc index 1b957c548b6..7c6ec045b2d 100644 --- a/tensorflow/core/kernels/summary_audio_op_test.cc +++ b/tensorflow/core/kernels/summary_audio_op_test.cc @@ -93,7 +93,7 @@ TEST_F(SummaryAudioOpTest, Basic3D) { Tensor* out_tensor = GetOutput(0); ASSERT_EQ(0, out_tensor->dims()); Summary summary; - ParseProtoUnlimited(&summary, out_tensor->scalar()()); + ParseProtoUnlimited(&summary, out_tensor->scalar()()); CheckAndRemoveEncodedAudio(&summary); EXPECT_SummaryMatches(summary, R"( @@ -127,7 +127,7 @@ TEST_F(SummaryAudioOpTest, Basic2D) { Tensor* out_tensor = GetOutput(0); ASSERT_EQ(0, out_tensor->dims()); Summary summary; - ParseProtoUnlimited(&summary, out_tensor->scalar()()); + ParseProtoUnlimited(&summary, out_tensor->scalar()()); CheckAndRemoveEncodedAudio(&summary); EXPECT_SummaryMatches(summary, R"( diff --git a/tensorflow/core/kernels/summary_image_op.cc b/tensorflow/core/kernels/summary_image_op.cc index 68f17c2e78d..bfba449c782 100644 --- a/tensorflow/core/kernels/summary_image_op.cc +++ b/tensorflow/core/kernels/summary_image_op.cc @@ -61,7 +61,7 @@ class SummaryImageOp : public OpKernel { errors::InvalidArgument( "Tensor must be 4-D with last dim 1, 3, or 4, not ", tensor.shape().DebugString())); - const string& base_tag = tags.scalar()(); + const string& base_tag = tags.scalar()(); OP_REQUIRES(c, tensor.dim_size(0) < (1LL << 31) && diff --git a/tensorflow/core/kernels/summary_image_op_test.cc b/tensorflow/core/kernels/summary_image_op_test.cc index 74e0d092c2d..be8e44d7511 100644 --- a/tensorflow/core/kernels/summary_image_op_test.cc +++ b/tensorflow/core/kernels/summary_image_op_test.cc @@ -87,7 +87,7 @@ TEST_F(SummaryImageOpTest, ThreeGrayImagesOutOfFive4dInput) { Tensor* out_tensor = GetOutput(0); ASSERT_EQ(0, out_tensor->dims()); Summary summary; - ParseProtoUnlimited(&summary, out_tensor->scalar()()); + ParseProtoUnlimited(&summary, out_tensor->scalar()()); CheckAndRemoveEncodedImages(&summary); EXPECT_SummaryMatches(summary, R"( @@ -110,7 +110,7 @@ TEST_F(SummaryImageOpTest, OneGrayImage4dInput) { Tensor* out_tensor = GetOutput(0); ASSERT_EQ(0, out_tensor->dims()); Summary summary; - ParseProtoUnlimited(&summary, out_tensor->scalar()()); + ParseProtoUnlimited(&summary, out_tensor->scalar()()); CheckAndRemoveEncodedImages(&summary); EXPECT_SummaryMatches(summary, R"( @@ -142,7 +142,7 @@ TEST_F(SummaryImageOpTest, OneColorImage4dInput) { Tensor* out_tensor = GetOutput(0); ASSERT_EQ(0, out_tensor->dims()); Summary summary; - ParseProtoUnlimited(&summary, out_tensor->scalar()()); + ParseProtoUnlimited(&summary, out_tensor->scalar()()); CheckAndRemoveEncodedImages(&summary); EXPECT_SummaryMatches(summary, R"( diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc index e17e28efc63..7f888da69d6 100644 --- a/tensorflow/core/kernels/summary_kernels.cc +++ b/tensorflow/core/kernels/summary_kernels.cc @@ -38,13 +38,13 @@ class CreateSummaryFileWriterOp : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor* tmp; OP_REQUIRES_OK(ctx, ctx->input("logdir", &tmp)); - const string logdir = tmp->scalar()(); + const string logdir = tmp->scalar()(); OP_REQUIRES_OK(ctx, ctx->input("max_queue", &tmp)); const int32 max_queue = tmp->scalar()(); OP_REQUIRES_OK(ctx, ctx->input("flush_millis", &tmp)); const int32 flush_millis = tmp->scalar()(); OP_REQUIRES_OK(ctx, ctx->input("filename_suffix", &tmp)); - const string filename_suffix = tmp->scalar()(); + const string filename_suffix = tmp->scalar()(); core::RefCountPtr s; OP_REQUIRES_OK(ctx, LookupOrCreateResource( @@ -67,13 +67,13 @@ class CreateSummaryDbWriterOp : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor* tmp; OP_REQUIRES_OK(ctx, ctx->input("db_uri", &tmp)); - const string db_uri = tmp->scalar()(); + const string db_uri = tmp->scalar()(); OP_REQUIRES_OK(ctx, ctx->input("experiment_name", &tmp)); - const string experiment_name = tmp->scalar()(); + const string experiment_name = tmp->scalar()(); OP_REQUIRES_OK(ctx, ctx->input("run_name", &tmp)); - const string run_name = tmp->scalar()(); + const string run_name = tmp->scalar()(); OP_REQUIRES_OK(ctx, ctx->input("user_name", &tmp)); - const string user_name = tmp->scalar()(); + const string user_name = tmp->scalar()(); core::RefCountPtr s; OP_REQUIRES_OK( @@ -132,9 +132,9 @@ class WriteSummaryOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); const int64 step = tmp->scalar()(); OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp)); - const string& tag = tmp->scalar()(); + const string& tag = tmp->scalar()(); OP_REQUIRES_OK(ctx, ctx->input("summary_metadata", &tmp)); - const string& serialized_metadata = tmp->scalar()(); + const string& serialized_metadata = tmp->scalar()(); const Tensor* t; OP_REQUIRES_OK(ctx, ctx->input("tensor", &t)); @@ -166,7 +166,7 @@ class WriteRawProtoSummaryOp : public OpKernel { // Each Summary proto contains just one repeated field "value" of Value // messages with the actual data, so repeated Merge() is equivalent to // concatenating all the Value entries together into a single Event. - const auto summary_pbs = t->flat(); + const auto summary_pbs = t->flat(); for (int i = 0; i < summary_pbs.size(); ++i) { if (!event->mutable_summary()->MergeFromString(summary_pbs(i))) { ctx->CtxFailureWithWarning(errors::DataLoss( @@ -191,7 +191,7 @@ class ImportEventOp : public OpKernel { const Tensor* t; OP_REQUIRES_OK(ctx, ctx->input("event", &t)); std::unique_ptr event{new Event}; - if (!ParseProtoUnlimited(event.get(), t->scalar()())) { + if (!ParseProtoUnlimited(event.get(), t->scalar()())) { ctx->CtxFailureWithWarning( errors::DataLoss("Bad tf.Event binary proto tensor string")); return; @@ -212,7 +212,7 @@ class WriteScalarSummaryOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); const int64 step = tmp->scalar()(); OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp)); - const string& tag = tmp->scalar()(); + const string& tag = tmp->scalar()(); const Tensor* t; OP_REQUIRES_OK(ctx, ctx->input("value", &t)); @@ -234,7 +234,7 @@ class WriteHistogramSummaryOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); const int64 step = tmp->scalar()(); OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp)); - const string& tag = tmp->scalar()(); + const string& tag = tmp->scalar()(); const Tensor* t; OP_REQUIRES_OK(ctx, ctx->input("values", &t)); @@ -262,7 +262,7 @@ class WriteImageSummaryOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); const int64 step = tmp->scalar()(); OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp)); - const string& tag = tmp->scalar()(); + const string& tag = tmp->scalar()(); const Tensor* bad_color; OP_REQUIRES_OK(ctx, ctx->input("bad_color", &bad_color)); OP_REQUIRES( @@ -297,7 +297,7 @@ class WriteAudioSummaryOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); const int64 step = tmp->scalar()(); OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp)); - const string& tag = tmp->scalar()(); + const string& tag = tmp->scalar()(); OP_REQUIRES_OK(ctx, ctx->input("sample_rate", &tmp)); const float sample_rate = tmp->scalar()(); @@ -326,7 +326,7 @@ class WriteGraphSummaryOp : public OpKernel { const int64 step = t->scalar()(); OP_REQUIRES_OK(ctx, ctx->input("tensor", &t)); std::unique_ptr graph{new GraphDef}; - if (!ParseProtoUnlimited(graph.get(), t->scalar()())) { + if (!ParseProtoUnlimited(graph.get(), t->scalar()())) { ctx->CtxFailureWithWarning( errors::DataLoss("Bad tf.GraphDef binary proto tensor string")); return; diff --git a/tensorflow/core/kernels/summary_op.cc b/tensorflow/core/kernels/summary_op.cc index 1053aa7d53a..a765825e5b0 100644 --- a/tensorflow/core/kernels/summary_op.cc +++ b/tensorflow/core/kernels/summary_op.cc @@ -47,7 +47,7 @@ class SummaryScalarOp : public OpKernel { errors::InvalidArgument( "tags and values not the same shape: ", tags.shape().DebugString(), " != ", values.shape().DebugString(), SingleTag(tags))); - auto Ttags = tags.flat(); + auto Ttags = tags.flat(); auto Tvalues = values.flat(); Summary s; for (int i = 0; i < Ttags.size(); i++) { @@ -64,7 +64,7 @@ class SummaryScalarOp : public OpKernel { // If there's only one tag, include it in the error message static string SingleTag(const Tensor& tags) { if (tags.NumElements() == 1) { - return strings::StrCat(" (tag '", tags.flat()(0), "')"); + return strings::StrCat(" (tag '", tags.flat()(0), "')"); } else { return ""; } @@ -138,7 +138,7 @@ class SummaryMergeOp : public OpKernel { std::unordered_set tags; for (int input_num = 0; input_num < c->num_inputs(); input_num++) { const Tensor& in = c->input(input_num); - auto in_vec = in.flat(); + auto in_vec = in.flat(); for (int i = 0; i < in_vec.dimension(0); i++) { const string& s_in = in_vec(i); Summary summary_in; diff --git a/tensorflow/core/kernels/summary_op_test.cc b/tensorflow/core/kernels/summary_op_test.cc index 697c03a0082..9dcc98eeefe 100644 --- a/tensorflow/core/kernels/summary_op_test.cc +++ b/tensorflow/core/kernels/summary_op_test.cc @@ -88,7 +88,7 @@ TEST_F(SummaryScalarOpTest, SimpleDouble) { Tensor* out_tensor = GetOutput(0); ASSERT_EQ(0, out_tensor->dims()); Summary summary; - ParseProtoUnlimited(&summary, out_tensor->scalar()()); + ParseProtoUnlimited(&summary, out_tensor->scalar()()); EXPECT_SummaryMatches(summary, R"( value { tag: 'tag1' simple_value: 1.0 } value { tag: 'tag2' simple_value: -0.73 } @@ -100,7 +100,7 @@ TEST_F(SummaryScalarOpTest, SimpleHalf) { MakeOp(DT_HALF); // Feed and run - AddInputFromList(TensorShape({3}), {"tag1", "tag2", "tag3"}); + AddInputFromList(TensorShape({3}), {"tag1", "tag2", "tag3"}); AddInputFromList(TensorShape({3}), {1.0, -2.0, 10000.0}); TF_ASSERT_OK(RunOpKernel()); @@ -108,7 +108,7 @@ TEST_F(SummaryScalarOpTest, SimpleHalf) { Tensor* out_tensor = GetOutput(0); ASSERT_EQ(0, out_tensor->dims()); Summary summary; - ParseProtoUnlimited(&summary, out_tensor->scalar()()); + ParseProtoUnlimited(&summary, out_tensor->scalar()()); EXPECT_SummaryMatches(summary, R"( value { tag: 'tag1' simple_value: 1.0 } value { tag: 'tag2' simple_value: -2.0 } @@ -177,7 +177,7 @@ TEST_F(SummaryHistoOpTest, SimpleFloat) { Tensor* out_tensor = GetOutput(0); ASSERT_EQ(0, out_tensor->dims()); Summary summary; - ParseProtoUnlimited(&summary, out_tensor->scalar()()); + ParseProtoUnlimited(&summary, out_tensor->scalar()()); ASSERT_EQ(summary.value_size(), 1); EXPECT_EQ(summary.value(0).tag(), "taghisto"); histogram::Histogram histo; @@ -205,7 +205,7 @@ TEST_F(SummaryHistoOpTest, SimpleDouble) { Tensor* out_tensor = GetOutput(0); ASSERT_EQ(0, out_tensor->dims()); Summary summary; - ParseProtoUnlimited(&summary, out_tensor->scalar()()); + ParseProtoUnlimited(&summary, out_tensor->scalar()()); ASSERT_EQ(summary.value_size(), 1); EXPECT_EQ(summary.value(0).tag(), "taghisto"); histogram::Histogram histo; @@ -234,7 +234,7 @@ TEST_F(SummaryHistoOpTest, SimpleHalf) { Tensor* out_tensor = GetOutput(0); ASSERT_EQ(0, out_tensor->dims()); Summary summary; - ParseProtoUnlimited(&summary, out_tensor->scalar()()); + ParseProtoUnlimited(&summary, out_tensor->scalar()()); ASSERT_EQ(summary.value_size(), 1); EXPECT_EQ(summary.value(0).tag(), "taghisto"); histogram::Histogram histo; @@ -308,7 +308,7 @@ TEST_F(SummaryMergeOpTest, Simple) { Tensor* out_tensor = GetOutput(0); ASSERT_EQ(0, out_tensor->dims()); Summary summary; - ParseProtoUnlimited(&summary, out_tensor->scalar()()); + ParseProtoUnlimited(&summary, out_tensor->scalar()()); EXPECT_SummaryMatches(summary, "value { tag: \"tag1\" simple_value: 1.0 } " @@ -342,7 +342,7 @@ TEST_F(SummaryMergeOpTest, Simple_MultipleInputs) { Tensor* out_tensor = GetOutput(0); ASSERT_EQ(0, out_tensor->dims()); Summary summary; - ParseProtoUnlimited(&summary, out_tensor->scalar()()); + ParseProtoUnlimited(&summary, out_tensor->scalar()()); EXPECT_SummaryMatches(summary, "value { tag: \"tag1\" simple_value: 1.0 } " diff --git a/tensorflow/core/kernels/summary_tensor_op_test.cc b/tensorflow/core/kernels/summary_tensor_op_test.cc index 55a0cb3ec5a..6bc4d150c2a 100644 --- a/tensorflow/core/kernels/summary_tensor_op_test.cc +++ b/tensorflow/core/kernels/summary_tensor_op_test.cc @@ -80,14 +80,14 @@ TEST_F(SummaryTensorOpV2Test, BasicPluginData) { Tensor* out_tensor = GetOutput(0); ASSERT_EQ(0, out_tensor->dims()); Summary summary; - ParseProtoUnlimited(&summary, out_tensor->scalar()()); + ParseProtoUnlimited(&summary, out_tensor->scalar()()); ASSERT_EQ(1, summary.value_size()); // Check the content of the tensor stored in the summary. Tensor string_content_tensor; CHECK(string_content_tensor.FromProto(summary.value(0).tensor())); ASSERT_EQ("some string tensor content", - string_content_tensor.scalar()()); + string_content_tensor.scalar()()); // Check plugin-related data. ASSERT_EQ("tag_foo", summary.value(0).tag()); diff --git a/tensorflow/core/kernels/tensor_array.cc b/tensorflow/core/kernels/tensor_array.cc index 8e8faf89837..2bd6ac0b08d 100644 --- a/tensorflow/core/kernels/tensor_array.cc +++ b/tensorflow/core/kernels/tensor_array.cc @@ -91,8 +91,8 @@ Status TensorArray::CopyShapesFrom(TensorArray* rhs, if (tensors_.size() != rhs->tensors_.size()) { return errors::InvalidArgument( "TensorArray sizes do not match during CopyShapesFrom: ", - handle_.vec()(1), " has size ", tensors_.size(), " but rhs ", - rhs->handle_.vec()(1), " has size ", rhs->tensors_.size()); + handle_.vec()(1), " has size ", tensors_.size(), " but rhs ", + rhs->handle_.vec()(1), " has size ", rhs->tensors_.size()); } for (std::size_t i = 0; i < tensors_.size(); ++i) { // Skip "soft copy" of indices which have not been written. diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h index 964b4631023..bea97d1a1f1 100644 --- a/tensorflow/core/kernels/tensor_array.h +++ b/tensorflow/core/kernels/tensor_array.h @@ -365,7 +365,7 @@ class TensorArray : public ResourceBase { Status LockedReturnIfClosed() const EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (closed_) { - return errors::InvalidArgument("TensorArray ", handle_.vec()(1), + return errors::InvalidArgument("TensorArray ", handle_.vec()(1), " has already been closed."); } return Status::OK(); @@ -447,7 +447,7 @@ Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx, size_t index_size = static_cast(index); if (index < 0 || (!dynamic_size_ && index_size >= tensors_.size())) { return errors::InvalidArgument( - "TensorArray ", handle_.vec()(1), ": Tried to write to index ", + "TensorArray ", handle_.vec()(1), ": Tried to write to index ", index, " but array is not resizeable and size is: ", tensors_.size()); } if (dynamic_size_) { @@ -464,14 +464,14 @@ Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx, Tensor* value_t = value->AccessTensor(ctx); if (value_t->dtype() != dtype_) { return errors::InvalidArgument( - "TensorArray ", handle_.vec()(1), + "TensorArray ", handle_.vec()(1), ": Could not write to TensorArray index ", index, " because the value dtype is ", DataTypeString(value_t->dtype()), " but TensorArray dtype is ", DataTypeString(dtype_), "."); } if (!element_shape_.IsCompatibleWith(value_t->shape())) { return errors::InvalidArgument( - "TensorArray ", handle_.vec()(1), + "TensorArray ", handle_.vec()(1), ": Could not write to TensorArray index ", index, " because the value shape is ", value_t->shape().DebugString(), " which is incompatible with the TensorArray's inferred element " @@ -482,13 +482,13 @@ Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx, } if (t.read) { - return errors::InvalidArgument("TensorArray ", handle_.vec()(1), + return errors::InvalidArgument("TensorArray ", handle_.vec()(1), ": Could not write to TensorArray index ", index, " because it has already been read."); } if (!multiple_writes_aggregate_ && t.written) { - return errors::InvalidArgument("TensorArray ", handle_.vec()(1), + return errors::InvalidArgument("TensorArray ", handle_.vec()(1), ": Could not write to TensorArray index ", index, " because it has already been written to."); @@ -500,7 +500,7 @@ Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx, // Check that value_t shape matches t.shape if (value_t->shape() != t.shape) { return errors::InvalidArgument( - "TensorArray ", handle_.vec()(1), + "TensorArray ", handle_.vec()(1), ": Could not aggregate to TensorArray index ", index, " because the existing shape is ", t.shape.DebugString(), " but the new input shape is ", value_t->shape().DebugString(), "."); @@ -568,7 +568,7 @@ Status TensorArray::LockedRead(OpKernelContext* ctx, const int32 index, element_shape = tensors_[index].shape; } else if (!element_shape_.IsFullyDefined()) { return errors::InvalidArgument( - "TensorArray ", handle_.vec()(1), + "TensorArray ", handle_.vec()(1), ": Could not read from TensorArray index ", index, ". Furthermore, the element shape is not fully defined: ", element_shape_.DebugString(), @@ -598,7 +598,7 @@ Status TensorArray::LockedRead(OpKernelContext* ctx, const int32 index, TensorAndState& t = tensors_[index]; if (t.cleared) { - return errors::InvalidArgument("TensorArray ", handle_.vec()(1), + return errors::InvalidArgument("TensorArray ", handle_.vec()(1), ": Could not read index ", index, " twice because it was cleared after a " "previous read (perhaps try setting " diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index d5c9470cc89..52162e94650 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -65,7 +65,7 @@ Status GetHandle(OpKernelContext* ctx, string* container, string* ta_handle) { "Tensor array handle must be 2-element vector, but had shape: ", tensor.shape().DebugString()); } - auto h = tensor.flat(); + auto h = tensor.flat(); *container = h(0); *ta_handle = h(1); } @@ -194,7 +194,7 @@ class TensorArrayOp : public TensorArrayCreationOp { return errors::InvalidArgument("Size should be >= 0."); } - auto handle = tensor_array_output_handle->flat(); + auto handle = tensor_array_output_handle->flat(); string unique_tensor_array_name = strings::StrCat(tensor_array_name_, "_", TensorArray::tensor_array_counter.fetch_add(1)); @@ -301,7 +301,7 @@ class TensorArrayGradOp : public TensorArrayCreationOp { string(StringPiece(resource.name()).substr(container.size())); } - auto output_handle = tensor_array_output_handle->flat(); + auto output_handle = tensor_array_output_handle->flat(); output_handle(0) = "_tensor_array_grads"; output_handle(1) = strings::StrCat(tensor_array_name, "@", source_); diff --git a/tensorflow/core/kernels/tensor_forest/resource_ops.cc b/tensorflow/core/kernels/tensor_forest/resource_ops.cc index c225d83674f..0c7b9e91263 100644 --- a/tensorflow/core/kernels/tensor_forest/resource_ops.cc +++ b/tensorflow/core/kernels/tensor_forest/resource_ops.cc @@ -34,7 +34,7 @@ class TensorForestCreateTreeVariableOp : public OpKernel { auto* const result = new TensorForestTreeResource(); - if (!result->InitFromSerialized(tree_config_t->scalar()())) { + if (!result->InitFromSerialized(tree_config_t->scalar()())) { result->Unref(); OP_REQUIRES(context, false, errors::InvalidArgument("Unable to parse tree config.")); @@ -63,7 +63,7 @@ class TensorForestTreeSerializeOp : public OpKernel { Tensor* output_config_t = nullptr; OP_REQUIRES_OK( context, context->allocate_output(0, TensorShape(), &output_config_t)); - output_config_t->scalar()() = + output_config_t->scalar()() = decision_tree_resource->decision_tree().SerializeAsString(); } }; @@ -86,7 +86,7 @@ class TensorForestTreeDeserializeOp : public OpKernel { decision_tree_resource->Reset(); if (!decision_tree_resource->InitFromSerialized( - tree_config_t->scalar()())) { + tree_config_t->scalar()())) { OP_REQUIRES(context, false, errors::InvalidArgument("Unable to parse tree config.")); } diff --git a/tensorflow/core/kernels/unicode_ops.cc b/tensorflow/core/kernels/unicode_ops.cc index 59ebbedcd7f..0bb5f0f7ef6 100644 --- a/tensorflow/core/kernels/unicode_ops.cc +++ b/tensorflow/core/kernels/unicode_ops.cc @@ -295,10 +295,10 @@ class UnicodeTranscodeOp : public OpKernel { } else { OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(), &output_tensor)); - output_tensor->flat() = input_tensor->flat(); + output_tensor->flat() = input_tensor->flat(); } - auto output_flat = output_tensor->flat(); + auto output_flat = output_tensor->flat(); bool found_any_format_error = false; for (size_t i = 0; i < output_flat.size(); ++i) { Transcode(&(output_flat(i)), input_encoder->converter_, @@ -404,7 +404,7 @@ class UnicodeDecodeBaseOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); // Go through all the strings in `input`. - const auto& input_vec = input_tensor->flat(); + const auto& input_vec = input_tensor->flat(); std::unique_ptr input_encoder = absl::make_unique(); @@ -538,7 +538,7 @@ class UnicodeEncodeOp : public OpKernel { Tensor* output_tensor; OP_REQUIRES_OK(context, context->allocate_output("output", output_shape, &output_tensor)); - auto output_tensor_flat = output_tensor->flat(); + auto output_tensor_flat = output_tensor->flat(); // Use a single index over the flattened input values tensor. int idx = 0; diff --git a/tensorflow/core/kernels/unsorted_segment_join_op.cc b/tensorflow/core/kernels/unsorted_segment_join_op.cc index 4ab890c44bd..f0b9388f7cf 100644 --- a/tensorflow/core/kernels/unsorted_segment_join_op.cc +++ b/tensorflow/core/kernels/unsorted_segment_join_op.cc @@ -115,9 +115,9 @@ class UnsortedSegmentJoinOp : public OpKernel { &output_tensor)); // Preprating flat tensors. - auto output_flat = output_tensor->flat(); + auto output_flat = output_tensor->flat(); auto flat_segment_id = segment_id.flat(); - auto flat_input = input.flat(); + auto flat_input = input.flat(); for (int i = 0; i < flat_segment_id.size(); i++) { OP_REQUIRES( diff --git a/tensorflow/core/kernels/whole_file_read_ops.cc b/tensorflow/core/kernels/whole_file_read_ops.cc index b617b76a508..1e3b7fd6b30 100644 --- a/tensorflow/core/kernels/whole_file_read_ops.cc +++ b/tensorflow/core/kernels/whole_file_read_ops.cc @@ -135,14 +135,14 @@ class WriteFileOp : public OpKernel { errors::InvalidArgument( "Contents tensor must be scalar, but had shape: ", contents_input->shape().DebugString())); - const string& filename = filename_input->scalar()(); + const string& filename = filename_input->scalar()(); const string dir(io::Dirname(filename)); if (!context->env()->FileExists(dir).ok()) { OP_REQUIRES_OK(context, context->env()->RecursivelyCreateDir(dir)); } OP_REQUIRES_OK(context, WriteStringToFile(context->env(), filename, - contents_input->scalar()())); + contents_input->scalar()())); } }; diff --git a/tensorflow/core/kernels/word2vec_kernels.cc b/tensorflow/core/kernels/word2vec_kernels.cc index 3477445197a..42b70e92bab 100644 --- a/tensorflow/core/kernels/word2vec_kernels.cc +++ b/tensorflow/core/kernels/word2vec_kernels.cc @@ -209,14 +209,14 @@ class SkipgramOp : public OpKernel { vocab_size_ = static_cast(1 + ordered.size()); Tensor word(DT_STRING, TensorShape({vocab_size_})); Tensor freq(DT_INT32, TensorShape({vocab_size_})); - word.flat()(0) = "UNK"; + word.flat()(0) = "UNK"; static const int32 kUnkId = 0; std::unordered_map word_id; int64 total_counted = 0; for (std::size_t i = 0; i < ordered.size(); ++i) { const auto& w = ordered[i].first; auto id = i + 1; - word.flat()(id) = w; + word.flat()(id) = w; auto word_count = ordered[i].second; freq.flat()(id) = word_count; total_counted += word_count; diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 4034e97e010..3f0f0c28569 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -3398,7 +3398,7 @@ REGISTER_OP("Fingerprint") return errors::InvalidArgument("`method` must be rank 0: ", method->shape()); } - const string& method_string = method->scalar()(); + const string& method_string = method->scalar()(); if (method_string != "farmhash64") { return errors::InvalidArgument("Unsupported method: ", method_string); } diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc index 1f2edee9054..e2078e001e4 100644 --- a/tensorflow/core/ops/io_ops.cc +++ b/tensorflow/core/ops/io_ops.cc @@ -101,7 +101,7 @@ REGISTER_OP("RestoreV2") const Tensor* shape_and_slices_tensor = c->input_tensor(2); if (shape_and_slices_tensor) { const auto& shape_and_slices_flat = - shape_and_slices_tensor->flat(); + shape_and_slices_tensor->flat(); if (shape_and_slices_flat.size() != c->num_outputs()) { return errors::InvalidArgument( "The number of shape_and_slice doesn't match tensor outputs."); @@ -222,7 +222,7 @@ REGISTER_OP("RestoreSlice") const Tensor* shape_and_slices_tensor = c->input_tensor(2); if (shape_and_slices_tensor) { const auto& shape_and_slice = - shape_and_slices_tensor->flat()(0); + shape_and_slices_tensor->flat()(0); if (shape_and_slice.empty()) { c->set_output(0, c->UnknownShape()); } else { diff --git a/tensorflow/core/summary/summary_db_writer.cc b/tensorflow/core/summary/summary_db_writer.cc index b203d439ccf..1a9bd33a110 100644 --- a/tensorflow/core/summary/summary_db_writer.cc +++ b/tensorflow/core/summary/summary_db_writer.cc @@ -676,7 +676,7 @@ class SeriesWriter { const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) { if (t.dtype() == DT_STRING) { if (t.dims() == 0) { - return Update(db, step, computed_time, t, t.scalar()(), rowid); + return Update(db, step, computed_time, t, t.scalar()(), rowid); } else { SqliteTransaction txn(*db); TF_RETURN_IF_ERROR( @@ -735,7 +735,7 @@ class SeriesWriter { )sql"; SqliteStatement inserter; TF_RETURN_IF_ERROR(db->Prepare(inserter_sql, &inserter)); - auto flat = t.flat(); + auto flat = t.flat(); for (int64 i = 0; i < flat.size(); ++i) { inserter.BindInt(1, tensor_rowid); inserter.BindInt(2, i); @@ -751,7 +751,7 @@ class SeriesWriter { unflushed_bytes_ = 0; if (t.dtype() == DT_STRING) { if (t.dims() == 0) { - TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.scalar()().size())); + TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.scalar()().size())); } else { TF_RETURN_IF_ERROR(ReserveTensors(db, &txn, kReserveMinBytes)); } @@ -1106,9 +1106,9 @@ class SummaryDbWriter : public SummaryWriterInterface { // See tensorboard/plugins/image/summary.py and data_compat.py Tensor t{DT_STRING, {3}}; auto img = s->mutable_image(); - t.flat()(0) = strings::StrCat(img->width()); - t.flat()(1) = strings::StrCat(img->height()); - t.flat()(2) = std::move(*img->mutable_encoded_image_string()); + t.flat()(0) = strings::StrCat(img->width()); + t.flat()(1) = strings::StrCat(img->height()); + t.flat()(2) = std::move(*img->mutable_encoded_image_string()); int64 tag_id; PatchPluginName(s->mutable_metadata(), kImagePluginName); TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), @@ -1120,8 +1120,8 @@ class SummaryDbWriter : public SummaryWriterInterface { // See tensorboard/plugins/audio/summary.py and data_compat.py Tensor t{DT_STRING, {1, 2}}; auto wav = s->mutable_audio(); - t.flat()(0) = std::move(*wav->mutable_encoded_audio_string()); - t.flat()(1) = ""; + t.flat()(0) = std::move(*wav->mutable_encoded_audio_string()); + t.flat()(1) = ""; int64 tag_id; PatchPluginName(s->mutable_metadata(), kAudioPluginName); TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), diff --git a/tensorflow/core/summary/summary_file_writer_test.cc b/tensorflow/core/summary/summary_file_writer_test.cc index 41060d7fe64..932ae80ab09 100644 --- a/tensorflow/core/summary/summary_file_writer_test.cc +++ b/tensorflow/core/summary/summary_file_writer_test.cc @@ -109,7 +109,7 @@ TEST_F(SummaryFileWriterTest, WriteTensor) { "string_tensor_test", [](SummaryWriterInterface* writer) { Tensor hello(DT_STRING, TensorShape({})); - hello.scalar()() = "hello"; + hello.scalar()() = "hello"; TF_RETURN_IF_ERROR(writer->WriteTensor( 2, hello, "name", SummaryMetadata().SerializeAsString())); TF_RETURN_IF_ERROR(writer->Flush()); diff --git a/tensorflow/core/util/batch_util.cc b/tensorflow/core/util/batch_util.cc index e1c32cd0069..3d704c4cdee 100644 --- a/tensorflow/core/util/batch_util.cc +++ b/tensorflow/core/util/batch_util.cc @@ -107,8 +107,8 @@ void HandleSliceToElement(Tensor* parent, Tensor* element, int64 index, template <> void HandleSliceToElement(Tensor* parent, Tensor* element, int64 index, bool can_move) { - auto parent_as_matrix = parent->flat_outer_dims(); - auto element_flat = element->flat(); + auto parent_as_matrix = parent->flat_outer_dims(); + auto element_flat = element->flat(); if (can_move) { for (int64 i = 0; i < element->NumElements(); ++i) { element_flat(i) = std::move(parent_as_matrix(index, i)); diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc index 06179d49a8b..4e49d4e6329 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.cc +++ b/tensorflow/core/util/example_proto_fast_parsing.cc @@ -852,8 +852,8 @@ Status FastParseSerializedExample( break; } case DT_STRING: { - std::copy_n(in.flat().data(), num_elements, - out.flat().data() + offset); + std::copy_n(in.flat().data(), num_elements, + out.flat().data() + offset); break; } default: @@ -1194,7 +1194,7 @@ Status FastParseExample(const Config& config, } case DT_STRING: { std::move(buffer.bytes_list.begin(), buffer.bytes_list.end(), - values->flat().data() + offset); + values->flat().data() + offset); break; } default: @@ -1578,7 +1578,7 @@ Status FastParseSingleExample(const Config& config, case DT_STRING: { *out = Tensor(out_dtype, out_shape); CopyOrMoveBlock(bytes_list.begin(), bytes_list.end(), - out->flat().data()); + out->flat().data()); break; } default: @@ -2079,7 +2079,7 @@ Status FastParseSequenceExample( int64* out_int64 = nullptr; switch (dtype) { case DT_STRING: - out_bytes = context_result->dense_values[t].flat().data(); + out_bytes = context_result->dense_values[t].flat().data(); break; case DT_FLOAT: out_float = context_result->dense_values[t].flat().data(); @@ -2113,7 +2113,7 @@ Status FastParseSequenceExample( size_t num = 0; switch (dtype) { case DT_STRING: - in_bytes = c.default_value.flat().data(); + in_bytes = c.default_value.flat().data(); num = c.default_value.NumElements(); for (int p = 0; p < num; p++) { *out_bytes++ = *in_bytes++; @@ -2190,7 +2190,7 @@ Status FastParseSequenceExample( int64* out_int64 = nullptr; switch (dtype) { case DT_STRING: - out_bytes = context_result->sparse_values[t].flat().data(); + out_bytes = context_result->sparse_values[t].flat().data(); break; case DT_FLOAT: out_float = context_result->sparse_values[t].flat().data(); @@ -2281,7 +2281,7 @@ Status FastParseSequenceExample( int64* out_int64 = nullptr; switch (dtype) { case DT_STRING: - out_bytes = feature_list_result->dense_values[t].flat().data(); + out_bytes = feature_list_result->dense_values[t].flat().data(); break; case DT_FLOAT: out_float = feature_list_result->dense_values[t].flat().data(); @@ -2392,7 +2392,8 @@ Status FastParseSequenceExample( int64* out_int64 = nullptr; switch (dtype) { case DT_STRING: - out_bytes = feature_list_result->sparse_values[t].flat().data(); + out_bytes = + feature_list_result->sparse_values[t].flat().data(); break; case DT_FLOAT: out_float = feature_list_result->sparse_values[t].flat().data(); diff --git a/tensorflow/core/util/example_proto_helper_test.cc b/tensorflow/core/util/example_proto_helper_test.cc index 1bf430b2c78..141c2400e91 100644 --- a/tensorflow/core/util/example_proto_helper_test.cc +++ b/tensorflow/core/util/example_proto_helper_test.cc @@ -57,7 +57,7 @@ class SingleExampleProtoToTensorsTest : public ::testing::Test { string_dense_config.dtype = DT_STRING; string_dense_config.shape = TensorShape({1}); string_dense_config.default_value = Tensor(DT_STRING, TensorShape({1})); - string_dense_config.default_value.scalar()() = "default"; + string_dense_config.default_value.scalar()() = "default"; dense_vec_.push_back(string_dense_config); // Setup sparse feature configuration. @@ -115,7 +115,7 @@ TEST_F(SingleExampleProtoToTensorsTest, SparseOnlyTrivial) { const std::vector& string_tensor_vec = output_sparse_values_tmp[2]; EXPECT_EQ(1, string_tensor_vec.size()); - EXPECT_EQ("forty-two", string_tensor_vec[0].vec()(0)); + EXPECT_EQ("forty-two", string_tensor_vec[0].vec()(0)); } TEST_F(SingleExampleProtoToTensorsTest, SparseOnlyEmpty) { @@ -143,7 +143,7 @@ TEST_F(SingleExampleProtoToTensorsTest, SparseOnlyEmpty) { const std::vector& string_tensor_vec = output_sparse_values_tmp[2]; EXPECT_EQ(1, string_tensor_vec.size()); - EXPECT_EQ(0, string_tensor_vec[0].vec().size()); + EXPECT_EQ(0, string_tensor_vec[0].vec().size()); } TEST_F(SingleExampleProtoToTensorsTest, DenseOnlyTrivial) { @@ -182,8 +182,8 @@ TEST_F(SingleExampleProtoToTensorsTest, DenseOnlyTrivial) { EXPECT_EQ(1, float_dense_output.matrix().size()); EXPECT_NEAR(4.2, float_dense_output.matrix()(0, 0), 0.001); - EXPECT_EQ(1, str_dense_output.matrix().size()); - EXPECT_EQ("forty-two", str_dense_output.matrix()(0, 0)); + EXPECT_EQ(1, str_dense_output.matrix().size()); + EXPECT_EQ("forty-two", str_dense_output.matrix()(0, 0)); } TEST_F(SingleExampleProtoToTensorsTest, DenseOnlyDefaults) { @@ -211,8 +211,8 @@ TEST_F(SingleExampleProtoToTensorsTest, DenseOnlyDefaults) { EXPECT_EQ(1, float_dense_output.matrix().size()); EXPECT_NEAR(0.0, float_dense_output.matrix()(0, 0), 0.001); - EXPECT_EQ(1, str_dense_output.matrix().size()); - EXPECT_EQ("default", str_dense_output.matrix()(0, 0)); + EXPECT_EQ(1, str_dense_output.matrix().size()); + EXPECT_EQ("default", str_dense_output.matrix()(0, 0)); } } // namespace diff --git a/tensorflow/core/util/sparse/sparse_tensor_test.cc b/tensorflow/core/util/sparse/sparse_tensor_test.cc index 5ab0a3d084e..f2faad23313 100644 --- a/tensorflow/core/util/sparse/sparse_tensor_test.cc +++ b/tensorflow/core/util/sparse/sparse_tensor_test.cc @@ -181,7 +181,7 @@ TEST(SparseTensorTest, SparseTensorConstruction) { Tensor vals(DT_STRING, TensorShape({N})); auto ix_t = ix.matrix(); - auto vals_t = vals.vec(); + auto vals_t = vals.vec(); vals_t = vals_c; ix_t = ix_c; @@ -362,7 +362,7 @@ TEST(SparseTensorTest, SparseTensorToDenseTensor) { Tensor vals(DT_STRING, TensorShape({N})); auto ix_t = GetSimpleIndexTensor(N, NDIM); - auto vals_t = vals.vec(); + auto vals_t = vals.vec(); ix.matrix() = ix_t; @@ -402,7 +402,7 @@ TEST(SparseTensorTest, SparseTensorToLargerDenseTensor) { Tensor vals(DT_STRING, TensorShape({N})); auto ix_t = GetSimpleIndexTensor(N, NDIM); - auto vals_t = vals.vec(); + auto vals_t = vals.vec(); ix.matrix() = ix_t; @@ -540,7 +540,7 @@ TEST(SparseTensorTest, Concat) { auto ix_c = GetSimpleIndexTensor(N, NDIM); auto ix_t = ix.matrix(); - auto vals_t = vals.vec(); + auto vals_t = vals.vec(); ix_t = ix_c; @@ -561,7 +561,7 @@ TEST(SparseTensorTest, Concat) { TF_EXPECT_OK(concatted.IndicesValid()); auto conc_ix_t = concatted.indices().matrix(); - auto conc_vals_t = concatted.values().vec(); + auto conc_vals_t = concatted.values().vec(); for (int n = 0; n < 4; ++n) { for (int i = 0; i < N; ++i) { @@ -750,7 +750,7 @@ static void BM_SparseReorderString(int iters, int N32, int NDIM32) { TensorShape shape; std::vector order; auto ix_t = ix.matrix(); - auto vals_t = vals.vec(); + auto vals_t = vals.vec(); for (int i = 0; i < N32; ++i) { int len = rnd.Rand32() % 1000; vals_t(i).resize(len); diff --git a/tensorflow/python/framework/test_ops.cc b/tensorflow/python/framework/test_ops.cc index 5d1386c26d7..550d5babcf7 100644 --- a/tensorflow/python/framework/test_ops.cc +++ b/tensorflow/python/framework/test_ops.cc @@ -96,13 +96,13 @@ class KernelLabelOp : public OpKernel { ctx->allocate_output("result", TensorShape({}), &output)); switch (KL) { case DEFAULT_LABEL: - output->scalar()() = "My label is: default"; + output->scalar()() = "My label is: default"; break; case OVERLOAD_1_LABEL: - output->scalar()() = "My label is: overload_1"; + output->scalar()() = "My label is: overload_1"; break; case OVERLOAD_2_LABEL: - output->scalar()() = "My label is: overload_2"; + output->scalar()() = "My label is: overload_2"; break; } } @@ -676,7 +676,7 @@ class DevicePlacementOp : public OpKernel { Tensor* output; OP_REQUIRES_OK(ctx, ctx->allocate_output("device", TensorShape({}), &output)); - output->scalar()() = ctx->device()->name(); + output->scalar()() = ctx->device()->name(); } }; diff --git a/tensorflow/python/kernel_tests/ackermann_op.cc b/tensorflow/python/kernel_tests/ackermann_op.cc index d42ca6f662e..2d885b7a0f0 100644 --- a/tensorflow/python/kernel_tests/ackermann_op.cc +++ b/tensorflow/python/kernel_tests/ackermann_op.cc @@ -35,7 +35,7 @@ class AckermannOp : public OpKernel { Tensor* output_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), &output_tensor)); - auto output = output_tensor->scalar(); + auto output = output_tensor->scalar(); output() = "A(m, 0) == A(m-1, 1)"; } diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc index e5187ab8727..7ebba437e4c 100644 --- a/tensorflow/tools/benchmark/benchmark_model.cc +++ b/tensorflow/tools/benchmark/benchmark_model.cc @@ -101,7 +101,7 @@ void CreateTensorsFromInputInfo( if (!input.initialization_values.empty()) { LOG(FATAL) << "Initialization values are not supported for strings"; } - auto type_tensor = input_tensor.flat(); + auto type_tensor = input_tensor.flat(); type_tensor = type_tensor.constant(""); break; } diff --git a/tensorflow/tools/graph_transforms/sparsify_gather.cc b/tensorflow/tools/graph_transforms/sparsify_gather.cc index 49e5cca461f..cc4078dfb85 100644 --- a/tensorflow/tools/graph_transforms/sparsify_gather.cc +++ b/tensorflow/tools/graph_transforms/sparsify_gather.cc @@ -126,7 +126,7 @@ Status ObtainTensorSlice(const GraphDef& input_graph_def, if (node.name() == tensor_names_node) { Tensor tensor_names_tensor; TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor_names_tensor)); - const auto& tensor_names_value = tensor_names_tensor.flat(); + const auto& tensor_names_value = tensor_names_tensor.flat(); for (int i = 0; i < tensor_names_value.size(); i++) { if (tensor_names_value(i) == GetMonolithicTensorKey(target_name)) { offset = i; @@ -144,7 +144,7 @@ Status ObtainTensorSlice(const GraphDef& input_graph_def, Tensor shape_and_slices_tensor; TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &shape_and_slices_tensor)); const auto& shape_and_slices_value = - shape_and_slices_tensor.flat(); + shape_and_slices_tensor.flat(); *shape_slice_string = shape_and_slices_value(offset); return Status::OK(); } diff --git a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc index b8d6ba00de8..dfe8fb0e32b 100644 --- a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc +++ b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc @@ -116,7 +116,7 @@ class SparsifyGatherTest : public ::testing::Test { NodeDef* tensor_shapes_slices_node = CreateNode( "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def); Tensor shapes_slices_val(DT_STRING, TensorShape({1})); - shapes_slices_val.flat()(0) = "4 1 0,4:0,1"; + shapes_slices_val.flat()(0) = "4 1 0,4:0,1"; SetNodeTensorAttr("value", shapes_slices_val, tensor_shapes_slices_node); @@ -327,8 +327,8 @@ class SparsifyGatherTest : public ::testing::Test { NodeDef* tensor_shapes_slices_node = CreateNode( "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def); Tensor shapes_slices_val(DT_STRING, TensorShape({2})); - shapes_slices_val.flat()(0) = "4 1 0,4:0,1"; - shapes_slices_val.flat()(1) = "4 1 0,4:0,1"; + shapes_slices_val.flat()(0) = "4 1 0,4:0,1"; + shapes_slices_val.flat()(1) = "4 1 0,4:0,1"; SetNodeTensorAttr("value", shapes_slices_val, tensor_shapes_slices_node);