# Scraped file metadata: last modified 2019-05-24 04:45:48 -07:00; 20 lines, 494 B.

# Description:
# Python API for shardings in XLA.
# Package-wide defaults applied to every target declared in this BUILD file.
package(
    # Targets are restricted to TensorFlow-internal consumers unless a
    # target sets its own `visibility`.
    default_visibility = ["//tensorflow:internal"],
    # License declaration for this package.
    licenses = ["notice"],  # Apache 2.0
)

# Python library built from xla_sharding.py.
py_library(
    name = "xla_sharding",
    srcs = ["xla_sharding.py"],
    # Overrides the package's default_visibility so that any package may
    # depend on this target.
    visibility = ["//visibility:public"],
    deps = [
        # XLA data protos (Python bindings).
        "//tensorflow/compiler/xla:xla_data_proto_py",
        # XLA python_api helper libraries.
        "//tensorflow/compiler/xla/python_api:types",
        "//tensorflow/compiler/xla/python_api:xla_shape",
        "//third_party/py/numpy",
    ],
)