Threads through a CancellationManager and uses it to cancel all associated operations on failure; see the sketch below.
PiperOrigin-RevId: 356806701
Change-Id: I7c41c6242f40018ec9385ed65771d08b3763365d
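A minimal sketch of the pattern, assuming nothing about TensorFlow's internal CancellationManager beyond what this entry states; the class and method names below are hypothetical stand-ins, not TF's API:

```python
import threading


class CancellationManager:
    def __init__(self):
        self._lock = threading.Lock()
        self._cancelled = False
        self._callbacks = []

    def register(self, cancel_fn):
        # Register a cancel callback; fire immediately if cancellation
        # has already started.
        with self._lock:
            if not self._cancelled:
                self._callbacks.append(cancel_fn)
                return
        cancel_fn()

    def start_cancel(self):
        # Cancel all associated operations exactly once.
        with self._lock:
            if self._cancelled:
                return
            self._cancelled = True
            callbacks, self._callbacks = self._callbacks, []
        for cb in callbacks:
            cb()


# One manager is threaded through a group of operations; the first
# failure cancels everything else registered against it.
manager = CancellationManager()
manager.register(lambda: print("op A cancelled"))
manager.register(lambda: print("op B cancelled"))
try:
    raise RuntimeError("op C failed")
except RuntimeError:
    manager.start_cancel()
```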
For now, throws an error if the shape of the overall tensor is queried. The plumbing required to make the shape information look like a graph tensor's not-fully-defined shape appears shallow if we want to go that route.
This means that querying the shape of a parallel tensor is now a blocking operation (and needs a status return), rather than creation itself blocking; see the sketch below.
PiperOrigin-RevId: 351907155
Change-Id: I2610613efd4bb6aafa44fc78ee53824fb6020b6a
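A minimal sketch of the new contract, assuming nothing beyond the entry above; ParallelTensor here is a hypothetical stand-in class, with concurrent.futures standing in for per-device async shape computation and the raised exception standing in for the status return:

```python
from concurrent.futures import Future


class ParallelTensor:
    def __init__(self, component_shape_futures):
        # Creation returns immediately; component shapes may still be
        # in flight on their respective devices.
        self._futures = component_shape_futures

    def shape(self):
        # The shape query is the blocking operation: result() waits and
        # re-raises any per-device error (the "status return").
        shapes = [f.result() for f in self._futures]
        if any(s != shapes[0] for s in shapes[1:]):
            raise ValueError(f"components have differing shapes: {shapes}")
        return shapes[0]


f0, f1 = Future(), Future()
t = ParallelTensor([f0, f1])  # non-blocking construction
f0.set_result((2, 3))
f1.set_result((2, 3))
print(t.shape())              # blocks only here; prints (2, 3)
```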
Stacks are not part of the proto. Moved to a TF2-only test and run with the TF2_BEHAVIOR env var set; see the snippet below.
PiperOrigin-RevId: 350861872
Change-Id: Id7e7f0c502e0acfd1c7d45a3ef64e4d99034a04e
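A sketch of how the env var comes into play, assuming a build in which both TF1 and TF2 behaviors are available; TF2_BEHAVIOR is consulted when TensorFlow is first imported, so it must already be set in the process environment (as the test runner does):

```python
import os

# Must be set before TensorFlow is imported; it is read at import time.
os.environ["TF2_BEHAVIOR"] = "1"

import tensorflow as tf  # noqa: E402

# With TF2 behavior enabled, eager execution is the default.
assert tf.executing_eagerly()
```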
Stacks are not part of the proto. Moved to a TF2-only test and run with the TF2_BEHAVIOR env var set.
PiperOrigin-RevId: 350832704
Change-Id: I342bdf724a9842e9f1b3c49095d8db1ec0c56076
Multiple devices are now supported within BatchFunction. Currently, inputs and outputs must still be on the CPU, since the concatenation/splitting is done on the CPU; a usage sketch follows this entry.
PiperOrigin-RevId: 347524478
Change-Id: Ib329987bf09513570c3c260e4c0834d6102a4364
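A usage sketch via tf.nondifferentiable_batch_function, the public Python wrapper around the BatchFunction op; the parameter values here are arbitrary, and whether the body actually lands on a second device depends on placement and available hardware:

```python
import tensorflow as tf


@tf.nondifferentiable_batch_function(
    num_batch_threads=1,
    max_batch_size=8,
    batch_timeout_micros=100_000)
def batched_matmul(x):
    # The function body may be placed on a non-CPU device; per this
    # change, the batched inputs/outputs still live on the CPU, where
    # concatenation and splitting happen.
    return tf.matmul(x, tf.ones([4, 4]))


# Concurrent calls are concatenated on the CPU, run once as a batch,
# and the result is split back out to the callers.
out = batched_matmul(tf.ones([2, 4]))
```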
Otherwise, instantiated functions have no Python stack traces, and one has to
retrieve them manually through the FunctionDefinitionLibrary.
More importantly, it makes it impossible to assign different stack traces in
cases where that is actually required, namely inlining; see the sketch below.
PiperOrigin-RevId: 346435458
Change-Id: I41a3188e453566fbae6d29b261eefe4d24cf6453
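An illustrative sketch of why per-instantiation traces matter for inlining, under the assumption (not stated in this entry) that traces are keyed per node; record_trace and inline are hypothetical helpers, not TF's plumbing:

```python
import traceback

# Hypothetical per-node trace table; not TF's internal representation.
node_stack_traces = {}


def record_trace(node_name):
    # Drop the last frame (this helper) so the trace points at the
    # caller that actually created the node.
    node_stack_traces[node_name] = traceback.extract_stack()[:-1]


def inline(call_site, inlined_node):
    # On inlining, the inlined copy should carry the call site's trace,
    # not the trace from where the function body was originally defined.
    node_stack_traces[f"{call_site}/{inlined_node}"] = (
        node_stack_traces[call_site])


record_trace("PartitionedCall_0")
inline("PartitionedCall_0", "MatMul")
print(sorted(node_stack_traces))
```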
Some years back, I forked the TF bfloat16 NumPy extension to create a JAX version of it. The TF version has not been actively maintained, whereas the JAX version is substantially more feature-complete (e.g., it implements most of the NumPy ufuncs).
However, having two different NumPy extensions that register the same type causes problems: for example, if someone loads the (less complete) TF implementation first, it takes priority over the (more complete) JAX implementation. Fix this by merging the two implementations and replacing the TF bfloat16 implementation with the JAX version.
Ideally we would go one step further and move the bfloat16 code into its own pip package that TF, JAX, and other systems can share, but we leave this for future work.
A side effect of this change is that calls to numpy.testing.assert_allclose require an explicit cast to a non-bfloat16 type; see the example below.
PiperOrigin-RevId: 346350783
Change-Id: Ic4d26457f9c9f50ef4c31b4adc3e938101c8e037
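A concrete illustration of that side effect, using tf.bfloat16.as_numpy_dtype to obtain the registered NumPy scalar type (the float32 cast target is just one reasonable choice):

```python
import numpy as np
import tensorflow as tf

# The bfloat16 NumPy scalar type registered by the extension.
bfloat16 = tf.bfloat16.as_numpy_dtype

a = np.array([1.0, 2.0, 3.0], dtype=bfloat16)
b = np.array([1.0, 2.0, 3.0], dtype=bfloat16)

# assert_allclose needs its operands cast to a non-bfloat16 type;
# comparing the bfloat16 arrays directly may fail inside NumPy.
np.testing.assert_allclose(a.astype(np.float32), b.astype(np.float32))
```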
Move the ownership of Python stack traces to the Graph object and make them accessible from the C++ API.
Expose stack-printing options and implement common-prefix filtering; see the sketch below.
PiperOrigin-RevId: 345579757
Change-Id: I88673891e893b1f71a5b039e44f0bc30f190c18a
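A minimal sketch of what common-prefix filtering could look like, assuming only the description above; filter_common_prefix is a hypothetical name, not the API this change adds:

```python
import os.path
import traceback


def filter_common_prefix(frames):
    # Hypothetical sketch: when every frame's file shares a directory
    # prefix, print paths relative to that prefix so the distinguishing
    # suffixes stand out.
    prefix = os.path.commonpath([f.filename for f in frames])
    return "\n".join(
        f"  {os.path.relpath(f.filename, prefix)}:{f.lineno} in {f.name}"
        for f in frames)


print(filter_common_prefix(traceback.extract_stack()))
```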
It was "large" and TAP filtered it so didn't run. It takes 13 sec to run.
PiperOrigin-RevId: 345130363
Change-Id: Ica00310f2b548c282c3dea42f6218e59d595ce28
Multiple devices are now supported within BatchFunction. Currently, inputs and outputs must still be on the CPU, since the concatenation/splitting is done on the CPU.
PiperOrigin-RevId: 343979298
Change-Id: Icc2bd49fe66c4dd80622c921afacea427e87ac16
These tests target features that TFRT does not plan to support for single-host training, including quantization, XLA, and clusters.
PiperOrigin-RevId: 343969137
Change-Id: Ib0d3daa1f3b38545a30030b4be321ca40d15062e