docs/source/accelerator/profiler.md
PyTorch ships a device-agnostic profiler that instruments CPU-side operator dispatch, coordinates with accelerator collectors, captures Python stacks, and exports aggregated statistics or Chrome/Perfetto traces. For the core architecture, see `torch/csrc/profiler/README.md`.
There are two primary integration paths for accelerators:

1. **Legacy autograd profiler**: `ProfilerStubs` to record device events and compute elapsed times.
2. **Kineto-based timeline**
This document focuses on path (1): how a `PrivateUse1` accelerator exposes the minimal hooks to plug into the legacy autograd profiler so that ATen ops and `record_function` ranges are correctly attributed to device activity.
| Layer | Responsibility | Source |
|---|---|---|
| Python control plane | Owns the profiler lifecycle (prepare → start → stop → step) and exposes user APIs such as `torch.autograd.profiler.profile`. | `torch/autograd/profiler.py` |
| Profiler stubs | Implements `torch::profiler::impl::ProfilerStubs` so the profiler can record device events, synchronize, iterate devices, and compute elapsed time. | `torch/csrc/profiler/stubs/` |
| Device runtime | Provides the streams, events, and device guards used by the stubs; the implementation is backend-specific. | Backend extension (vendor code) |
This layering keeps PyTorch device-agnostic: Python brokers the session, ProfilerStubs translate profiler requests into backend runtime calls, and the runtime interacts with the accelerator.
- `record()` must capture an (optional) device index, allocate a backend event, optionally stash a CPU timestamp, and enqueue the event on the active stream.
- `elapsed()` is responsible for synchronizing individual events and returning durations in microseconds.
- `synchronize()` and `onEachDevice()` guarantee phase transitions (e.g., warmup → active) are aligned across devices.
- `mark`, `rangePush`, and `rangePop` can be implemented to enrich traces; otherwise they may be left as no-ops.

Here we use OpenReg (Open Registration) to illustrate the minimal set of hooks a `PrivateUse1` accelerator needs to expose so the profiler can attribute ATen ops, `record_function` ranges, and user code to device activity. OpenReg keeps upstream code untouched by translating profiler requests into its runtime calls, mirroring what a production accelerator would implement inside an out-of-tree extension.
OpenReg currently relies on the legacy profiler interface (`torch.autograd.profiler.profile`) rather than the modern one (`torch.profiler.profile`) because the latter enforces `use_kineto=True`.
`torch::profiler::impl::OpenRegMethods` inherits from `ProfilerStubs` and wires up the hooks described above:
| Method | Purpose |
|---|---|
| `record` | Grabs the current `OpenRegStream`, creates an `orEvent`, captures an optional CPU timestamp via `c10::getTime()`, and records the event on the stream. |
| `elapsed` | Synchronizes both events, calls `orEventElapsedTime`, and converts milliseconds to microseconds for the profiler. |
| `onEachDevice` | Uses `c10::DeviceGuard(DeviceType::PrivateUse1)` to iterate over `torch.openreg.device_count()` so schedulers can run per-device setup or teardown. |
| `synchronize` | Calls `orDeviceSynchronize()` to align device work with CPU scheduling phases. |
| `enabled` and annotation shims | Report availability and provide placeholder implementations for `mark`/`push`/`pop`. |
The constructor registers the methods once via `registerPrivateUse1Methods(&methods)`, making them discoverable whenever the profiler is enabled with `use_device="openreg"`.
On the Python side, no new entrypoint is required—developers use the standard autograd profiler:
```python
import torch
from torch.autograd.profiler import profile as autograd_profile
from torch.profiler import record_function

with autograd_profile(use_device="openreg", record_shapes=True) as prof:
    with record_function("matmul"):
        x = torch.randn(512, 512, device="openreg")
        y = torch.randn(512, 512, device="openreg")
        z = x @ y

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
prof.export_chrome_trace("openreg_trace.json")
```
Under the hood, the session proceeds as follows:

1. `autograd_profile(use_device="openreg")` puts the profiler into `ProfilerState.KINETO_PRIVATEUSE1_FALLBACK`.
2. Each profiled operation asks the registered stubs to `record()` an event.
3. The stubs allocate `orEvent` objects, attach them to the current stream, and stash CPU timestamps.
4. When results are aggregated, the profiler calls `elapsed()` to compute durations.