parent
97123eff94
commit
1ea0a18e43
@ -46,7 +46,7 @@ filegroup(
|
||||
"*test*",
|
||||
],
|
||||
) + [
|
||||
"//tensorflow/cc:srcs_no_runtime",
|
||||
"//tensorflow/cc:srcs",
|
||||
"//tensorflow/core/distributed_runtime:server_lib.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
@ -202,7 +202,6 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -390,7 +389,6 @@ tf_cuda_library(
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
@ -2,7 +2,6 @@ load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_test",
|
||||
"tf_kernel_library",
|
||||
"tf_gen_op_libs",
|
||||
)
|
||||
|
||||
package(
|
||||
@ -24,17 +23,6 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = ["bitcast"],
|
||||
deps = [
|
||||
"//tensorflow/c:ops",
|
||||
"//tensorflow/c:tf_datatype",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_tensor",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "bitcast_op_test",
|
||||
srcs = ["bitcast_op_test.cc"],
|
||||
|
@ -15,11 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -100,64 +97,5 @@ TEST(BitcastOpTest, TestImpossibleCast) {
|
||||
TestBitcastOp(&int8_input, DT_UINT32, TensorShape(), error::INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
PartialTensorShape S(std::initializer_list<int64> dims) {
|
||||
return PartialTensorShape(dims);
|
||||
}
|
||||
|
||||
TEST(BitcastOpTest, TestShapeInference_LargerShape) {
|
||||
const OpRegistrationData* reg;
|
||||
TF_CHECK_OK(OpRegistry::Global()->LookUp("Bitcast", ®));
|
||||
OpDef op_def = reg->op_def;
|
||||
NodeDef def;
|
||||
TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
|
||||
.Attr("type", DT_INT8)
|
||||
.Attr("T", DT_INT64)
|
||||
.Input(FakeInput(DT_INT64))
|
||||
.Finalize(&def));
|
||||
shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4})}, {}, {}, {});
|
||||
std::vector<shape_inference::ShapeHandle> input_shapes;
|
||||
TF_CHECK_OK(c.input("input", &input_shapes));
|
||||
ASSERT_EQ("[3,4]", c.DebugString(input_shapes[0]));
|
||||
TF_CHECK_OK(reg->shape_inference_fn(&c));
|
||||
ASSERT_EQ("[3,4,8]", c.DebugString(c.output(0)));
|
||||
}
|
||||
|
||||
TEST(BitcastOpTest, TestShapeInference_SmallerShape) {
|
||||
const OpRegistrationData* reg;
|
||||
TF_CHECK_OK(OpRegistry::Global()->LookUp("Bitcast", ®));
|
||||
OpDef op_def = reg->op_def;
|
||||
NodeDef def;
|
||||
TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
|
||||
.Attr("type", DT_INT64)
|
||||
.Attr("T", DT_INT8)
|
||||
.Input(FakeInput(DT_INT8))
|
||||
.Finalize(&def));
|
||||
shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4, 8})}, {}, {},
|
||||
{});
|
||||
std::vector<shape_inference::ShapeHandle> input_shapes;
|
||||
TF_CHECK_OK(c.input("input", &input_shapes));
|
||||
ASSERT_EQ("[3,4,8]", c.DebugString(input_shapes[0]));
|
||||
TF_CHECK_OK(reg->shape_inference_fn(&c));
|
||||
ASSERT_EQ("[3,4]", c.DebugString(c.output(0)));
|
||||
}
|
||||
|
||||
TEST(BitcastOpTest, TestShapeInference_SameShape) {
|
||||
const OpRegistrationData* reg;
|
||||
TF_CHECK_OK(OpRegistry::Global()->LookUp("Bitcast", ®));
|
||||
OpDef op_def = reg->op_def;
|
||||
NodeDef def;
|
||||
TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
|
||||
.Attr("type", DT_INT32)
|
||||
.Attr("T", DT_FLOAT)
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Finalize(&def));
|
||||
shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4})}, {}, {}, {});
|
||||
std::vector<shape_inference::ShapeHandle> input_shapes;
|
||||
TF_CHECK_OK(c.input("input", &input_shapes));
|
||||
ASSERT_EQ("[3,4]", c.DebugString(input_shapes[0]));
|
||||
TF_CHECK_OK(reg->shape_inference_fn(&c));
|
||||
ASSERT_EQ("[3,4]", c.DebugString(c.output(0)));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -1,132 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/ops.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
static void ComputeNewShape(TF_ShapeInferenceContext* ctx,
|
||||
TF_ShapeHandle* shape, size_t input_type_size,
|
||||
size_t output_type_size, TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
if (input_type_size < output_type_size) {
|
||||
TF_ShapeInferenceContextWithRankAtLeast(ctx, shape, 1, shape, status);
|
||||
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
TF_DimensionHandle* last_dim = TF_NewDimensionHandle();
|
||||
size_t divisor_val = output_type_size / input_type_size;
|
||||
TF_ShapeInferenceContextDim(ctx, shape, -1, last_dim);
|
||||
if (!TF_DimensionHandleValueKnown(last_dim) ||
|
||||
TF_DimensionHandleValue(last_dim) == divisor_val) {
|
||||
TF_ShapeInferenceContextSubshape(ctx, shape, 0, -1, shape, status);
|
||||
} else {
|
||||
std::ostringstream err;
|
||||
err << "Cannot bitcast due to shape. "
|
||||
<< TF_DimensionHandleValue(last_dim) << " does not match "
|
||||
<< divisor_val;
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
|
||||
}
|
||||
TF_DeleteDimensionHandle(last_dim);
|
||||
}
|
||||
} else if (input_type_size > output_type_size) {
|
||||
// Input type size is larger than output type size.
|
||||
size_t divisor_val = input_type_size / output_type_size;
|
||||
TF_ShapeHandle* extension =
|
||||
TF_ShapeInferenceContextVectorFromSize(ctx, divisor_val);
|
||||
TF_ShapeInferenceContextConcatenateShapes(ctx, shape, extension, shape,
|
||||
status);
|
||||
TF_DeleteShapeHandle(extension);
|
||||
}
|
||||
}
|
||||
|
||||
static void bitcast_shape_inference_fn(TF_ShapeInferenceContext* ctx,
|
||||
TF_Status* status) {
|
||||
TF_ShapeHandle* result = TF_NewShapeHandle();
|
||||
TF_ShapeInferenceContextGetInput(ctx, 0, result, status);
|
||||
if (TF_GetCode(status) == TF_OK &&
|
||||
!TF_ShapeInferenceContextRankKnown(ctx, result)) {
|
||||
TF_ShapeInferenceContextSetUnknownShape(ctx, status);
|
||||
TF_DeleteShapeHandle(result);
|
||||
return;
|
||||
}
|
||||
|
||||
// Find the size of the input and output data types.
|
||||
TF_DataType input_type;
|
||||
TF_DataType output_type;
|
||||
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
TF_ShapeInferenceContext_GetAttrType(ctx, "T", &input_type, status);
|
||||
}
|
||||
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
TF_ShapeInferenceContext_GetAttrType(ctx, "type", &output_type, status);
|
||||
}
|
||||
|
||||
size_t input_type_size;
|
||||
size_t output_type_size;
|
||||
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
input_type_size = TF_DataTypeSize(input_type);
|
||||
output_type_size = TF_DataTypeSize(output_type);
|
||||
|
||||
if (input_type_size == 0 || output_type_size == 0) {
|
||||
std::ostringstream err;
|
||||
err << "Cannot bitcast type " << input_type << " to " << output_type
|
||||
<< " because one of the type sizes is zero";
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
ComputeNewShape(ctx, result, input_type_size, output_type_size, status);
|
||||
}
|
||||
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
TF_ShapeInferenceContextSetOutput(ctx, 0, result, status);
|
||||
}
|
||||
TF_DeleteShapeHandle(result);
|
||||
}
|
||||
|
||||
static void RegisterBitcastOp() {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
|
||||
TF_OpDefinitionBuilder* op_builder = TF_NewOpDefinitionBuilder("Bitcast");
|
||||
TF_OpDefinitionBuilderAddInput(op_builder, "input: T");
|
||||
TF_OpDefinitionBuilderAddOutput(op_builder, "output: type");
|
||||
TF_OpDefinitionBuilderAddAttr(
|
||||
op_builder,
|
||||
"T: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
|
||||
"uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, "
|
||||
"qint16, quint16, qint32}");
|
||||
TF_OpDefinitionBuilderAddAttr(
|
||||
op_builder,
|
||||
"type: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
|
||||
"uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, "
|
||||
"qint16, quint16, qint32}");
|
||||
TF_OpDefinitionBuilderSetShapeInferenceFunction(op_builder,
|
||||
&bitcast_shape_inference_fn);
|
||||
|
||||
TF_RegisterOpDefinition(op_builder, status);
|
||||
CHECK_EQ(TF_GetCode(status), TF_OK)
|
||||
<< "Bitcast op registration failed: " << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
static bool IsBitcastOpRegistered = []() {
|
||||
RegisterBitcastOp();
|
||||
return true;
|
||||
}();
|
@ -17,17 +17,6 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "srcs_no_runtime",
|
||||
srcs = [
|
||||
"framework/gradients.h",
|
||||
"framework/ops.h",
|
||||
"framework/scope.h",
|
||||
"framework/scope_internal.h",
|
||||
"//tensorflow/cc/saved_model:loader.h",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "srcs",
|
||||
srcs = [
|
||||
@ -496,7 +485,6 @@ tf_gen_op_wrappers_cc(
|
||||
tf_gen_op_wrappers_cc(
|
||||
name = "array_ops",
|
||||
api_def_srcs = ["//tensorflow/core/api_def:base_api_def"],
|
||||
extra_gen_deps = ["//tensorflow/c/kernels:bitcast_op_lib"],
|
||||
op_lib_names = [
|
||||
"array_ops",
|
||||
],
|
||||
|
@ -1440,7 +1440,6 @@ cc_library(
|
||||
":training_ops_op_lib",
|
||||
":user_ops_op_lib",
|
||||
":word2vec_ops",
|
||||
"//tensorflow/c/kernels:bitcast_op_lib",
|
||||
] + if_mkl([
|
||||
":mkl_array_ops_op_lib",
|
||||
":mkl_nn_ops_op_lib",
|
||||
@ -1460,7 +1459,6 @@ cc_library(
|
||||
":array_ops_op_lib",
|
||||
":framework",
|
||||
":lib",
|
||||
"//tensorflow/c/kernels:bitcast_op_lib",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -3141,20 +3139,6 @@ CORE_CPU_BASE_HDRS = GRAPH_HDRS + [
|
||||
|
||||
tf_cuda_library(
|
||||
name = "core_cpu_base",
|
||||
hdrs = CORE_CPU_BASE_HDRS + ["public/session.h"],
|
||||
copts = tf_copts(),
|
||||
deps = [":core_cpu_base_no_ops"] + if_static([
|
||||
":function_ops_op_lib",
|
||||
":functional_grad",
|
||||
":functional_ops_op_lib",
|
||||
"//tensorflow/core/kernels:bounds_check",
|
||||
"//tensorflow/core/kernels:required",
|
||||
]),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "core_cpu_base_no_ops",
|
||||
srcs = [
|
||||
"common_runtime/eval_const_tensor.cc",
|
||||
"common_runtime/scoped_allocator.cc",
|
||||
@ -3163,10 +3147,11 @@ tf_cuda_library(
|
||||
"common_runtime/graph_optimizer.h",
|
||||
"graph/graph_constructor.cc", # Depends on common_runtime.
|
||||
"graph/graph_def_builder_util.cc", # Depends on common_runtime.
|
||||
"public/session.h",
|
||||
"public/session_options.h",
|
||||
"public/version.h",
|
||||
] + CORE_CPU_BASE_HDRS,
|
||||
hdrs = CORE_CPU_BASE_HDRS + ["public/session.h"],
|
||||
hdrs = CORE_CPU_BASE_HDRS,
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
":graph",
|
||||
@ -3179,8 +3164,14 @@ tf_cuda_library(
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"//third_party/eigen3",
|
||||
] + if_static([
|
||||
":function_ops_op_lib",
|
||||
":functional_grad",
|
||||
":functional_ops_op_lib",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"//tensorflow/core/kernels:bounds_check",
|
||||
"//tensorflow/core/kernels:required",
|
||||
]),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
|
||||
@ -3346,16 +3337,6 @@ tf_cuda_library(
|
||||
] + if_static([":core_cpu_impl"]) + tf_protos_all() + tf_protos_grappler(),
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "core_cpu_lib_no_ops",
|
||||
hdrs = CORE_CPU_LIB_HEADERS,
|
||||
deps = [
|
||||
":core_cpu_base_no_ops",
|
||||
":proto_text",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
] + tf_protos_all() + tf_protos_grappler(),
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "core_cpu_internal",
|
||||
srcs = [
|
||||
@ -4938,7 +4919,6 @@ tf_cc_test(
|
||||
":test",
|
||||
":test_main",
|
||||
":testlib",
|
||||
"//tensorflow/c/kernels:bitcast_op_lib",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/core/kernels:cwise_op",
|
||||
|
@ -178,7 +178,7 @@ cc_library(
|
||||
hdrs = ["functions.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu_base_no_ops",
|
||||
"//tensorflow/core:core_cpu_base",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -2888,6 +2888,69 @@ REGISTER_OP("ExtractVolumePatches")
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
REGISTER_OP("Bitcast")
|
||||
.Input("input: T")
|
||||
.Output("output: type")
|
||||
// All supported dtypes are listed here to include qint16, quint16, uint32,
|
||||
// and uint64.
|
||||
.Attr(
|
||||
"T: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
|
||||
"uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, "
|
||||
"qint16, quint16, qint32}")
|
||||
.Attr(
|
||||
"type: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
|
||||
"uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, "
|
||||
"qint16, quint16, qint32}")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle input = c->input(0);
|
||||
if (!c->RankKnown(input)) {
|
||||
// Input shape unknown.
|
||||
return shape_inference::UnknownShape(c);
|
||||
}
|
||||
|
||||
// Find the size of the input and output data types.
|
||||
DataType input_type;
|
||||
DataType output_type;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("T", &input_type));
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("type", &output_type));
|
||||
const int input_type_size = DataTypeSize(input_type);
|
||||
const int output_type_size = DataTypeSize(output_type);
|
||||
|
||||
if (input_type_size == 0 || output_type_size == 0) {
|
||||
return errors::InvalidArgument("Cannot bitcast types ",
|
||||
DataTypeString(input_type), " to ",
|
||||
DataTypeString(output_type),
|
||||
" because "
|
||||
"one of the type sizes is zero.");
|
||||
}
|
||||
|
||||
ShapeHandle new_shape;
|
||||
if (input_type_size == output_type_size) {
|
||||
// No change in size.
|
||||
new_shape = input;
|
||||
} else if (input_type_size < output_type_size) {
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 1, &new_shape));
|
||||
|
||||
int64 divisor_val = output_type_size / input_type_size;
|
||||
DimensionHandle last_dim = c->Dim(new_shape, -1);
|
||||
if (!c->ValueKnown(last_dim) || c->Value(last_dim) == divisor_val) {
|
||||
TF_RETURN_IF_ERROR(c->Subshape(new_shape, 0, -1, &new_shape));
|
||||
} else {
|
||||
return errors::InvalidArgument("Cannot bitcast due to shape. ",
|
||||
c->Value(last_dim), " does not match ",
|
||||
divisor_val);
|
||||
}
|
||||
} else {
|
||||
// Input type size is larger than output type size.
|
||||
int64 divisor_val = input_type_size / output_type_size;
|
||||
ShapeHandle extension = c->Vector(divisor_val);
|
||||
TF_RETURN_IF_ERROR(c->Concatenate(input, extension, &new_shape));
|
||||
}
|
||||
|
||||
c->set_output(0, new_shape);
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
REGISTER_OP("OneHot")
|
||||
.Input("indices: TI")
|
||||
.Input("depth: int32")
|
||||
|
@ -1902,10 +1902,6 @@ tf_gen_op_wrapper_private_py(
|
||||
"//tensorflow/contrib/quantization:__pkg__",
|
||||
"//tensorflow/python/kernel_tests:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/kernels:bitcast_op_lib",
|
||||
"//tensorflow/core:array_ops_op_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_private_py(
|
||||
|
Loading…
Reference in New Issue
Block a user