Move bitcast registration into C
This change makes the bitcast op use the new C op API. The purpose of this change is to demonstrate that op registration and shape inference can be performed using an ABI-stable API. This is a prerequisite to being able to load ops from external libraries in the modular TF design.

PiperOrigin-RevId: 255720454
This commit is contained in:
parent 5fbe995c00
commit 05ad600619
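For context, the registration pattern this change introduces looks roughly like the sketch below. It is not part of the diff; the "PassThrough" op, its single attribute, and the RegisterPassThroughOp / PassThroughShapeFn names are illustrative placeholders, and only C API calls that appear in the new tensorflow/c/kernels/ops/bitcast.cc are used here.

#include "tensorflow/c/ops.h"
#include "tensorflow/core/platform/logging.h"

// Shape inference through the ABI-stable C API: forward the input shape
// unchanged to the output.
static void PassThroughShapeFn(TF_ShapeInferenceContext* ctx,
                               TF_Status* status) {
  TF_ShapeHandle* shape = TF_NewShapeHandle();
  TF_ShapeInferenceContextGetInput(ctx, 0, shape, status);
  if (TF_GetCode(status) == TF_OK) {
    TF_ShapeInferenceContextSetOutput(ctx, 0, shape, status);
  }
  TF_DeleteShapeHandle(shape);
}

// Registers the op definition via TF_OpDefinitionBuilder, mirroring
// RegisterBitcastOp() in the new ops/bitcast.cc added by this commit.
void RegisterPassThroughOp() {
  TF_Status* status = TF_NewStatus();
  TF_OpDefinitionBuilder* op_builder = TF_NewOpDefinitionBuilder("PassThrough");
  TF_OpDefinitionBuilderAddInput(op_builder, "input: T");
  TF_OpDefinitionBuilderAddOutput(op_builder, "output: T");
  TF_OpDefinitionBuilderAddAttr(op_builder, "T: {float, int32}");
  TF_OpDefinitionBuilderSetShapeInferenceFunction(op_builder,
                                                  &PassThroughShapeFn);
  TF_RegisterOpDefinition(op_builder, status);
  CHECK_EQ(TF_GetCode(status), TF_OK)
      << "PassThrough op registration failed: " << TF_Message(status);
  TF_DeleteStatus(status);
}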
@@ -46,7 +46,7 @@ filegroup(
            "*test*",
        ],
    ) + [
        "//tensorflow/cc:srcs",
        "//tensorflow/cc:srcs_no_runtime",
        "//tensorflow/core/distributed_runtime:server_lib.h",
    ],
    visibility = ["//visibility:public"],
@@ -202,6 +202,7 @@ cc_library(
            "//tensorflow/core:framework",
        ],
    }),
    alwayslink = 1,
)

cc_library(
@@ -389,6 +390,7 @@ tf_cuda_library(
            "//tensorflow/core:framework",
        ],
    }),
    alwayslink = 1,
)

# -----------------------------------------------------------------------------
@@ -2,6 +2,7 @@ load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_test",
    "tf_kernel_library",
    "tf_gen_op_libs",
)

package(
@@ -23,6 +24,17 @@ tf_kernel_library(
    ],
)

tf_gen_op_libs(
    op_lib_names = ["bitcast"],
    deps = [
        "//tensorflow/c:ops",
        "//tensorflow/c:tf_datatype",
        "//tensorflow/c:tf_status",
        "//tensorflow/c:tf_tensor",
        "//tensorflow/core:lib",
    ],
)

tf_cc_test(
    name = "bitcast_op_test",
    srcs = ["bitcast_op_test.cc"],
@@ -44,9 +56,14 @@ tf_cc_test(
#
# LINT.IfChange
filegroup(
    name = "android_all_ops",
    name = "android_all_op_kernels",
    srcs = [
        "bitcast_op.cc",
    ],
)
# LINT.ThenChange(//tensorflow/contrib/makefile/tf_op_files.txt)

filegroup(
    name = "android_all_ops",
    srcs = ["ops/bitcast.cc"],
)
@@ -136,7 +136,7 @@ static void BitcastOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
  TF_DeleteTensor(tensor);
}

void RegisterBitcastOp() {
void RegisterBitcastOpKernel() {
  TF_Status* status = TF_NewStatus();
  {
    auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_CPU,
@@ -163,9 +163,9 @@ void RegisterBitcastOp() {

// A dummy static variable initialized by a lambda whose side-effect is to
// register the bitcast kernel.
static bool BitcastOpIsRegistered = []() {
static bool IsBitcastOpKernelRegistered = []() {
  if (SHOULD_REGISTER_OP_KERNEL("BitcastOp")) {
    RegisterBitcastOp();
    RegisterBitcastOpKernel();
  }
  return true;
}();
@@ -15,8 +15,11 @@ limitations under the License.

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
@@ -97,5 +100,64 @@ TEST(BitcastOpTest, TestImpossibleCast) {
  TestBitcastOp(&int8_input, DT_UINT32, TensorShape(), error::INVALID_ARGUMENT);
}

PartialTensorShape S(std::initializer_list<int64> dims) {
  return PartialTensorShape(dims);
}

TEST(BitcastOpTest, TestShapeInference_LargerShape) {
  const OpRegistrationData* reg;
  TF_CHECK_OK(OpRegistry::Global()->LookUp("Bitcast", &reg));
  OpDef op_def = reg->op_def;
  NodeDef def;
  TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
                  .Attr("type", DT_INT8)
                  .Attr("T", DT_INT64)
                  .Input(FakeInput(DT_INT64))
                  .Finalize(&def));
  shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4})}, {}, {}, {});
  std::vector<shape_inference::ShapeHandle> input_shapes;
  TF_CHECK_OK(c.input("input", &input_shapes));
  ASSERT_EQ("[3,4]", c.DebugString(input_shapes[0]));
  TF_CHECK_OK(reg->shape_inference_fn(&c));
  ASSERT_EQ("[3,4,8]", c.DebugString(c.output(0)));
}

TEST(BitcastOpTest, TestShapeInference_SmallerShape) {
  const OpRegistrationData* reg;
  TF_CHECK_OK(OpRegistry::Global()->LookUp("Bitcast", &reg));
  OpDef op_def = reg->op_def;
  NodeDef def;
  TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
                  .Attr("type", DT_INT64)
                  .Attr("T", DT_INT8)
                  .Input(FakeInput(DT_INT8))
                  .Finalize(&def));
  shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4, 8})}, {}, {},
                                      {});
  std::vector<shape_inference::ShapeHandle> input_shapes;
  TF_CHECK_OK(c.input("input", &input_shapes));
  ASSERT_EQ("[3,4,8]", c.DebugString(input_shapes[0]));
  TF_CHECK_OK(reg->shape_inference_fn(&c));
  ASSERT_EQ("[3,4]", c.DebugString(c.output(0)));
}

TEST(BitcastOpTest, TestShapeInference_SameShape) {
  const OpRegistrationData* reg;
  TF_CHECK_OK(OpRegistry::Global()->LookUp("Bitcast", &reg));
  OpDef op_def = reg->op_def;
  NodeDef def;
  TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
                  .Attr("type", DT_INT32)
                  .Attr("T", DT_FLOAT)
                  .Input(FakeInput(DT_FLOAT))
                  .Finalize(&def));
  shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4})}, {}, {}, {});
  std::vector<shape_inference::ShapeHandle> input_shapes;
  TF_CHECK_OK(c.input("input", &input_shapes));
  ASSERT_EQ("[3,4]", c.DebugString(input_shapes[0]));
  TF_CHECK_OK(reg->shape_inference_fn(&c));
  ASSERT_EQ("[3,4]", c.DebugString(c.output(0)));
}

}  // namespace
}  // namespace tensorflow
tensorflow/c/kernels/ops/bitcast.cc (new file, 135 lines)
@@ -0,0 +1,135 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <sstream>
#include <string>

#include "tensorflow/c/ops.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/platform/logging.h"

static void ComputeNewShape(TF_ShapeInferenceContext* ctx,
                            TF_ShapeHandle* shape, size_t input_type_size,
                            size_t output_type_size, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  if (input_type_size < output_type_size) {
    TF_ShapeInferenceContextWithRankAtLeast(ctx, shape, 1, shape, status);

    if (TF_GetCode(status) == TF_OK) {
      TF_DimensionHandle* last_dim = TF_NewDimensionHandle();
      size_t divisor_val = output_type_size / input_type_size;
      TF_ShapeInferenceContextDim(ctx, shape, -1, last_dim);
      if (!TF_DimensionHandleValueKnown(last_dim) ||
          TF_DimensionHandleValue(last_dim) == divisor_val) {
        TF_ShapeInferenceContextSubshape(ctx, shape, 0, -1, shape, status);
      } else {
        std::ostringstream err;
        err << "Cannot bitcast due to shape. "
            << TF_DimensionHandleValue(last_dim) << " does not match "
            << divisor_val;
        TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
      }
      TF_DeleteDimensionHandle(last_dim);
    }
  } else if (input_type_size > output_type_size) {
    // Input type size is larger than output type size.
    size_t divisor_val = input_type_size / output_type_size;
    TF_ShapeHandle* extension =
        TF_ShapeInferenceContextVectorFromSize(ctx, divisor_val);
    TF_ShapeInferenceContextConcatenateShapes(ctx, shape, extension, shape,
                                              status);
    TF_DeleteShapeHandle(extension);
  }
}

static void bitcast_shape_inference_fn(TF_ShapeInferenceContext* ctx,
                                       TF_Status* status) {
  TF_ShapeHandle* result = TF_NewShapeHandle();
  TF_ShapeInferenceContextGetInput(ctx, 0, result, status);
  if (TF_GetCode(status) == TF_OK &&
      !TF_ShapeInferenceContextRankKnown(ctx, result)) {
    TF_ShapeInferenceContextSetUnknownShape(ctx, status);
    TF_DeleteShapeHandle(result);
    return;
  }

  // Find the size of the input and output data types.
  TF_DataType input_type;
  TF_DataType output_type;

  if (TF_GetCode(status) == TF_OK) {
    TF_ShapeInferenceContext_GetAttrType(ctx, "T", &input_type, status);
  }

  if (TF_GetCode(status) == TF_OK) {
    TF_ShapeInferenceContext_GetAttrType(ctx, "type", &output_type, status);
  }

  size_t input_type_size;
  size_t output_type_size;

  if (TF_GetCode(status) == TF_OK) {
    input_type_size = TF_DataTypeSize(input_type);
    output_type_size = TF_DataTypeSize(output_type);

    if (input_type_size == 0 || output_type_size == 0) {
      std::ostringstream err;
      err << "Cannot bitcast type " << input_type << " to " << output_type
          << " because one of the type sizes is zero";
      TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
    }
  }

  if (TF_GetCode(status) == TF_OK) {
    ComputeNewShape(ctx, result, input_type_size, output_type_size, status);
  }

  if (TF_GetCode(status) == TF_OK) {
    TF_ShapeInferenceContextSetOutput(ctx, 0, result, status);
  }
  TF_DeleteShapeHandle(result);
}

void RegisterBitcastOp() {
  TF_Status* status = TF_NewStatus();

  TF_OpDefinitionBuilder* op_builder = TF_NewOpDefinitionBuilder("Bitcast");
  TF_OpDefinitionBuilderAddInput(op_builder, "input: T");
  TF_OpDefinitionBuilderAddOutput(op_builder, "output: type");
  TF_OpDefinitionBuilderAddAttr(
      op_builder,
      "T: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
      "uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, "
      "qint16, quint16, qint32}");
  TF_OpDefinitionBuilderAddAttr(
      op_builder,
      "type: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
      "uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, "
      "qint16, quint16, qint32}");
  TF_OpDefinitionBuilderSetShapeInferenceFunction(op_builder,
                                                  &bitcast_shape_inference_fn);

  TF_RegisterOpDefinition(op_builder, status);
  CHECK_EQ(TF_GetCode(status), TF_OK)
      << "Bitcast op registration failed: " << TF_Message(status);
  TF_DeleteStatus(status);
}

static bool IsBitcastOpRegistered = []() {
  if (SHOULD_REGISTER_OP("Bitcast")) {
    RegisterBitcastOp();
  }
  return true;
}();
@@ -17,6 +17,17 @@ package(
    licenses = ["notice"],  # Apache 2.0
)

filegroup(
    name = "srcs_no_runtime",
    srcs = [
        "framework/gradients.h",
        "framework/ops.h",
        "framework/scope.h",
        "framework/scope_internal.h",
        "//tensorflow/cc/saved_model:loader.h",
    ],
)

filegroup(
    name = "srcs",
    srcs = [
@@ -485,6 +496,7 @@ tf_gen_op_wrappers_cc(
tf_gen_op_wrappers_cc(
    name = "array_ops",
    api_def_srcs = ["//tensorflow/core/api_def:base_api_def"],
    extra_gen_deps = ["//tensorflow/c/kernels:bitcast_op_lib"],
    op_lib_names = [
        "array_ops",
    ],
@@ -1448,6 +1448,7 @@ cc_library(
        ":training_ops_op_lib",
        ":user_ops_op_lib",
        ":word2vec_ops",
        "//tensorflow/c/kernels:bitcast_op_lib",
    ] + if_mkl([
        ":mkl_array_ops_op_lib",
        ":mkl_nn_ops_op_lib",
@@ -1467,6 +1468,7 @@ cc_library(
        ":array_ops_op_lib",
        ":framework",
        ":lib",
        "//tensorflow/c/kernels:bitcast_op_lib",
    ],
    alwayslink = 1,
)
@@ -2083,7 +2085,7 @@ cc_library(

filegroup(
    name = "android_op_registrations_and_gradients",
    srcs = glob(
    srcs = ["//tensorflow/c/kernels:android_all_ops"] + glob(
        [
            "ops/**/*.cc",
            "ops/**/*.h",
@@ -3140,6 +3142,20 @@ CORE_CPU_BASE_HDRS = GRAPH_HDRS + [

tf_cuda_library(
    name = "core_cpu_base",
    hdrs = CORE_CPU_BASE_HDRS + ["public/session.h"],
    copts = tf_copts(),
    deps = [":core_cpu_base_no_ops"] + if_static([
        ":function_ops_op_lib",
        ":functional_grad",
        ":functional_ops_op_lib",
        "//tensorflow/core/kernels:bounds_check",
        "//tensorflow/core/kernels:required",
    ]),
    alwayslink = 1,
)

tf_cuda_library(
    name = "core_cpu_base_no_ops",
    srcs = [
        "common_runtime/eval_const_tensor.cc",
        "common_runtime/scoped_allocator.cc",
@@ -3148,11 +3164,10 @@ tf_cuda_library(
        "common_runtime/graph_optimizer.h",
        "graph/graph_constructor.cc",  # Depends on common_runtime.
        "graph/graph_def_builder_util.cc",  # Depends on common_runtime.
        "public/session.h",
        "public/session_options.h",
        "public/version.h",
    ] + CORE_CPU_BASE_HDRS,
    hdrs = CORE_CPU_BASE_HDRS,
    hdrs = CORE_CPU_BASE_HDRS + ["public/session.h"],
    copts = tf_copts(),
    deps = [
        ":graph",
@@ -3165,14 +3180,8 @@ tf_cuda_library(
        "@com_google_absl//absl/container:flat_hash_set",
        "//third_party/eigen3",
    ] + if_static([
        ":function_ops_op_lib",
        ":functional_grad",
        ":functional_ops_op_lib",
        "@com_google_absl//absl/algorithm:container",
        "//tensorflow/core/kernels:bounds_check",
        "//tensorflow/core/kernels:required",
    ]),
    alwayslink = 1,
)

CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
@@ -3338,6 +3347,16 @@ tf_cuda_library(
    ] + if_static([":core_cpu_impl"]) + tf_protos_all() + tf_protos_grappler(),
)

tf_cuda_library(
    name = "core_cpu_lib_no_ops",
    hdrs = CORE_CPU_LIB_HEADERS,
    deps = [
        ":core_cpu_base_no_ops",
        ":proto_text",
        "//tensorflow/core/grappler:grappler_item",
    ] + tf_protos_all() + tf_protos_grappler(),
)

tf_cuda_library(
    name = "core_cpu_internal",
    srcs = [
@@ -4925,6 +4944,7 @@ tf_cc_test(
        ":test",
        ":test_main",
        ":testlib",
        "//tensorflow/c/kernels:bitcast_op_lib",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:scope",
        "//tensorflow/core/kernels:cwise_op",
@@ -178,7 +178,7 @@ cc_library(
    hdrs = ["functions.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:core_cpu_base",
        "//tensorflow/core:core_cpu_base_no_ops",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
@@ -6049,7 +6049,7 @@ filegroup(
        "unpack_op.cc",
        "variable_ops.cc",
        "variable_ops.h",
        "//tensorflow/c/kernels:android_all_ops",
        "//tensorflow/c/kernels:android_all_op_kernels",
    ],
)

@@ -6392,7 +6392,7 @@ ANDROID_TEXTUAL_HDRS = [
# registration.
filegroup(
    name = "android_all_ops",
    srcs = ["//tensorflow/c/kernels:android_all_ops"] + glob(
    srcs = ["//tensorflow/c/kernels:android_all_op_kernels"] + glob(
        [
            "*.cc",
            "*.h",
@@ -2908,69 +2908,6 @@ REGISTER_OP("ExtractVolumePatches")

// --------------------------------------------------------------------------

REGISTER_OP("Bitcast")
    .Input("input: T")
    .Output("output: type")
    // All supported dtypes are listed here to include qint16, quint16, uint32,
    // and uint64.
    .Attr(
        "T: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
        "uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, "
        "qint16, quint16, qint32}")
    .Attr(
        "type: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
        "uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, "
        "qint16, quint16, qint32}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input = c->input(0);
      if (!c->RankKnown(input)) {
        // Input shape unknown.
        return shape_inference::UnknownShape(c);
      }

      // Find the size of the input and output data types.
      DataType input_type;
      DataType output_type;
      TF_RETURN_IF_ERROR(c->GetAttr("T", &input_type));
      TF_RETURN_IF_ERROR(c->GetAttr("type", &output_type));
      const int input_type_size = DataTypeSize(input_type);
      const int output_type_size = DataTypeSize(output_type);

      if (input_type_size == 0 || output_type_size == 0) {
        return errors::InvalidArgument("Cannot bitcast types ",
                                       DataTypeString(input_type), " to ",
                                       DataTypeString(output_type),
                                       " because "
                                       "one of the type sizes is zero.");
      }

      ShapeHandle new_shape;
      if (input_type_size == output_type_size) {
        // No change in size.
        new_shape = input;
      } else if (input_type_size < output_type_size) {
        TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 1, &new_shape));

        int64 divisor_val = output_type_size / input_type_size;
        DimensionHandle last_dim = c->Dim(new_shape, -1);
        if (!c->ValueKnown(last_dim) || c->Value(last_dim) == divisor_val) {
          TF_RETURN_IF_ERROR(c->Subshape(new_shape, 0, -1, &new_shape));
        } else {
          return errors::InvalidArgument("Cannot bitcast due to shape. ",
                                         c->Value(last_dim), " does not match ",
                                         divisor_val);
        }
      } else {
        // Input type size is larger than output type size.
        int64 divisor_val = input_type_size / output_type_size;
        ShapeHandle extension = c->Vector(divisor_val);
        TF_RETURN_IF_ERROR(c->Concatenate(input, extension, &new_shape));
      }

      c->set_output(0, new_shape);
      return Status::OK();
    });

REGISTER_OP("OneHot")
    .Input("indices: TI")
    .Input("depth: int32")
@@ -1883,6 +1883,10 @@ tf_gen_op_wrapper_private_py(
        "//tensorflow/contrib/quantization:__pkg__",
        "//tensorflow/python/kernel_tests:__pkg__",
    ],
    deps = [
        "//tensorflow/c/kernels:bitcast_op_lib",
        "//tensorflow/core:array_ops_op_lib",
    ],
)

tf_gen_op_wrapper_private_py(