Add PyTorch built-in SDPA to Optimus
Add PyTorch's built-in scaled dot-product attention (SDPA) to Optimus. SDPA automatically dispatches to FlashAttention-2 or memory-efficient attention when the hardware supports it, and falls back to the manual implementation otherwise. Training should be noticeably faster with this change, and attention memory usage should be roughly half of what it was before.
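A minimal sketch of what the dispatch could look like (the function name, shapes, and fallback structure are illustrative assumptions, not the actual Optimus code); `q`, `k`, `v` are assumed to be `(batch, heads, seq, head_dim)`:

```python
import math
import torch
import torch.nn.functional as F

def attention(q, k, v, is_causal=False, dropout_p=0.0):
    # Hypothetical helper for illustration; not the actual Optimus implementation.
    if hasattr(F, "scaled_dot_product_attention"):
        # PyTorch >= 2.0: dispatches to FlashAttention-2 or memory-efficient
        # attention when the hardware/backend supports it, else a math fallback.
        return F.scaled_dot_product_attention(
            q, k, v, dropout_p=dropout_p, is_causal=is_causal
        )
    # Manual fallback for older PyTorch: explicit softmax(QK^T / sqrt(d)) @ V.
    scale = 1.0 / math.sqrt(q.size(-1))
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    if is_causal:
        mask = torch.ones(
            q.size(-2), k.size(-2), dtype=torch.bool, device=q.device
        ).tril()
        scores = scores.masked_fill(~mask, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    attn = F.dropout(attn, p=dropout_p, training=True)
    return torch.matmul(attn, v)
```

The fused SDPA kernels avoid materializing the full `(seq, seq)` attention matrix, which is where the speed and memory savings come from; the manual path above keeps the same numerics for hardware or PyTorch versions without fused-kernel support.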