From 603a21810f9e129ce86c3c7cc4a6f37586f791b0 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Mon, 15 Feb 2021 23:41:21 -0800 Subject: [PATCH] Make "missing variant function" error status a bit more readable PiperOrigin-RevId: 357657033 Change-Id: Id97aad3ec193dee78dcfe0c8dc3d9fdb4e236ab7 --- .../core/framework/variant_op_registry.cc | 20 +++++++++++++++++++ .../core/framework/variant_op_registry.h | 20 +++++++++++-------- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/framework/variant_op_registry.cc b/tensorflow/core/framework/variant_op_registry.cc index aa3bdeab5e2..c63f1a336e9 100644 --- a/tensorflow/core/framework/variant_op_registry.cc +++ b/tensorflow/core/framework/variant_op_registry.cc @@ -26,6 +26,26 @@ limitations under the License. namespace tensorflow { +const char* VariantUnaryOpToString(VariantUnaryOp op) { + switch (op) { + case INVALID_VARIANT_UNARY_OP: + return "INVALID"; + case ZEROS_LIKE_VARIANT_UNARY_OP: + return "ZEROS_LIKE"; + case CONJ_VARIANT_UNARY_OP: + return "CONJ"; + } +} + +const char* VariantBinaryOpToString(VariantBinaryOp op) { + switch (op) { + case INVALID_VARIANT_BINARY_OP: + return "INVALID"; + case ADD_VARIANT_BINARY_OP: + return "ADD"; + } +} + std::unordered_set* UnaryVariantOpRegistry::PersistentStringStorage() { static std::unordered_set* string_storage = new std::unordered_set(); diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h index edfb9c544c0..6095407468b 100644 --- a/tensorflow/core/framework/variant_op_registry.h +++ b/tensorflow/core/framework/variant_op_registry.h @@ -44,11 +44,15 @@ enum VariantUnaryOp { CONJ_VARIANT_UNARY_OP = 2, }; +const char* VariantUnaryOpToString(VariantUnaryOp op); + enum VariantBinaryOp { INVALID_VARIANT_BINARY_OP = 0, ADD_VARIANT_BINARY_OP = 1, }; +const char* VariantBinaryOpToString(VariantBinaryOp op); + enum VariantDeviceCopyDirection { INVALID_DEVICE_COPY_DIRECTION = 0, HOST_TO_DEVICE = 1, @@ -311,9 +315,10 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v, UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn = UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId()); if (unary_op_fn == nullptr) { - return errors::Internal( - "No unary variant unary_op function found for unary variant op enum: ", - op, " Variant type_name: ", v.TypeName(), " for device type: ", device); + return errors::Internal("No unary variant unary_op function found for op ", + VariantUnaryOpToString(op), + " Variant type_name: ", v.TypeName(), + " for device type: ", device); } return (*unary_op_fn)(ctx, v, v_out); } @@ -340,11 +345,10 @@ Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op, UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn = UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId()); if (binary_op_fn == nullptr) { - return errors::Internal( - "No unary variant binary_op function found for binary variant op " - "enum: ", - op, " Variant type_name: '", a.TypeName(), "' for device type: ", - device); + return errors::Internal("No unary variant binary_op function found for op ", + VariantBinaryOpToString(op), + " Variant type_name: '", a.TypeName(), + "' for device type: ", device); } return (*binary_op_fn)(ctx, a, b, out); }