Make "missing variant function" error status a bit more readable
PiperOrigin-RevId: 357657033 Change-Id: Id97aad3ec193dee78dcfe0c8dc3d9fdb4e236ab7
This commit is contained in:
parent
460d16750d
commit
603a21810f
@ -26,6 +26,26 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
const char* VariantUnaryOpToString(VariantUnaryOp op) {
|
||||
switch (op) {
|
||||
case INVALID_VARIANT_UNARY_OP:
|
||||
return "INVALID";
|
||||
case ZEROS_LIKE_VARIANT_UNARY_OP:
|
||||
return "ZEROS_LIKE";
|
||||
case CONJ_VARIANT_UNARY_OP:
|
||||
return "CONJ";
|
||||
}
|
||||
}
|
||||
|
||||
const char* VariantBinaryOpToString(VariantBinaryOp op) {
|
||||
switch (op) {
|
||||
case INVALID_VARIANT_BINARY_OP:
|
||||
return "INVALID";
|
||||
case ADD_VARIANT_BINARY_OP:
|
||||
return "ADD";
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_set<string>* UnaryVariantOpRegistry::PersistentStringStorage() {
|
||||
static std::unordered_set<string>* string_storage =
|
||||
new std::unordered_set<string>();
|
||||
|
@ -44,11 +44,15 @@ enum VariantUnaryOp {
|
||||
CONJ_VARIANT_UNARY_OP = 2,
|
||||
};
|
||||
|
||||
const char* VariantUnaryOpToString(VariantUnaryOp op);
|
||||
|
||||
enum VariantBinaryOp {
|
||||
INVALID_VARIANT_BINARY_OP = 0,
|
||||
ADD_VARIANT_BINARY_OP = 1,
|
||||
};
|
||||
|
||||
const char* VariantBinaryOpToString(VariantBinaryOp op);
|
||||
|
||||
enum VariantDeviceCopyDirection {
|
||||
INVALID_DEVICE_COPY_DIRECTION = 0,
|
||||
HOST_TO_DEVICE = 1,
|
||||
@ -311,9 +315,10 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
|
||||
UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
|
||||
UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId());
|
||||
if (unary_op_fn == nullptr) {
|
||||
return errors::Internal(
|
||||
"No unary variant unary_op function found for unary variant op enum: ",
|
||||
op, " Variant type_name: ", v.TypeName(), " for device type: ", device);
|
||||
return errors::Internal("No unary variant unary_op function found for op ",
|
||||
VariantUnaryOpToString(op),
|
||||
" Variant type_name: ", v.TypeName(),
|
||||
" for device type: ", device);
|
||||
}
|
||||
return (*unary_op_fn)(ctx, v, v_out);
|
||||
}
|
||||
@ -340,11 +345,10 @@ Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
|
||||
UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
|
||||
UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId());
|
||||
if (binary_op_fn == nullptr) {
|
||||
return errors::Internal(
|
||||
"No unary variant binary_op function found for binary variant op "
|
||||
"enum: ",
|
||||
op, " Variant type_name: '", a.TypeName(), "' for device type: ",
|
||||
device);
|
||||
return errors::Internal("No unary variant binary_op function found for op ",
|
||||
VariantBinaryOpToString(op),
|
||||
" Variant type_name: '", a.TypeName(),
|
||||
"' for device type: ", device);
|
||||
}
|
||||
return (*binary_op_fn)(ctx, a, b, out);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user