third_party/xla/xla/megascale/tools/network_analysis_oss.ipynb
Given an Xprof profile, this notebook will generate several graphs and metrics that can help identify and understand potential performance issues.
Use: Upload one or more xprof profiles to use as input to this Colab.
The cells below must be run before jumping to other sections.
# @title Install necessary packages
import shutil
if shutil.which("pip") is None:
print("pip is not installed. Skipping package installation.")
else:
# The nightly release of JAX is required until the official release contains the necessary ProfileData API.
!pip install --pre -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
!pip install git+https://github.com/jax-ml/jax
print("Finished package installation.")
# @title Define helper functions
from IPython import display
import jax
import matplotlib.pyplot as plt
import pandas as pd
def bytes_to_human(n, precision=2):
  """Format a byte count as a human-readable string (B, KiB, ..., EiB).

  Args:
    n: Number of bytes. Negative values yield the sentinel "-NaN".
    precision: Decimal places used for scaled (>= 1 KiB) values.

  Returns:
    A string such as "512 B" or "1.50 MiB".
  """
  if n < 0:
    return "-NaN"
  base = 1024
  units = ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB"]
  # Sub-KiB values are reported verbatim, without scaling or decimals.
  if n < base:
    return f"{n} {units[0]}"
  # Find the largest unit that n reaches, capped at the last unit so
  # extremely large values are still expressed (in EiB).
  exponent = 0
  while exponent < len(units) - 1 and n >= base ** (exponent + 1):
    exponent += 1
  return f"{n / base**exponent:.{precision}f} {units[exponent]}"
Using the "Files" button in the vertical menu on the left-hand side, upload one or more xprof profiles.
# @title Generate a DataFrame for each profile.
profile_paths = '/content/example_profile_1.pb' # @param {'type': 'string', isTemplate: true}
labels = 'profile1' # @param {'type':'string', isTemplate: true}

profile_paths = profile_paths.split(',')
labels = labels.split(',')
assert len(profile_paths) == len(labels)

dfs = []
for i, xplane_path in enumerate(profile_paths):
  # The host (CPU) plane carries the MegaScale transport events.
  plane = jax.profiler.ProfileData.from_file(xplane_path).find_plane_with_name(
      '/host:CPU'
  )
  records = []
  for line in plane.lines:
    # Only MegaScale worker trace lines are of interest.
    if not line.name.startswith('MegascaleEM2_Worker'):
      continue
    for event in line.events:
      if event.name != 'MegaScale: Communication Transport Receive':
        continue
      stats = dict(event.stats)
      # Global device ids are formatted "<slice>-<device within slice>".
      src = f'{stats["dcn_source_slice_id"]}-{stats["dcn_source_per_slice_device_id"]}'
      dst = f'{stats["dcn_destination_slice_id"]}-{stats["dcn_destination_per_slice_device_id"]}'
      latency_us = stats['duration_us']
      start_ns = event.start_ns
      # Duration is in microseconds; convert to nanoseconds for the end time.
      end_ns = start_ns + latency_us * 1000
      records.append([
          latency_us,
          stats['payload_size_bytes'],
          src,
          dst,
          start_ns,
          end_ns,
          pd.to_datetime(end_ns, unit='ns'),
      ])
  df = pd.DataFrame(
      records,
      columns=[
          'latency_us',
          'bytes',
          'src',
          'dst',
          'start_ns',
          'end_ns',
          'timestamp',
      ],
  )
  # Index by transfer-completion timestamp; tag with the user-supplied label.
  df.set_index('timestamp', inplace=True)
  df.attrs['label'] = labels[i]
  dfs.append(df)
Check for outliers or persistently high latency. Small transfers should have lower latency than large ones.
Possible sources of high latency are network slowdowns or individual host slowness.
# @title Network transfer latency
for df in dfs:
  plt.figure(figsize=(10, 6))
  # One scatter series per distinct transfer size, sorted numerically so the
  # legend reads smallest-to-largest.
  for size in sorted(df['bytes'].unique()):
    series_data = df[df['bytes'] == size]
    plt.scatter(
        series_data.index,
        series_data['latency_us'] / 1000,  # microseconds -> milliseconds
        label=f'{bytes_to_human(size)}',
        s=20,
    )
  plt.title(f'Network transfer latency over time for {df.attrs.get("label")}')
  plt.xlabel('Time')
  plt.ylabel('Latency (ms)')
  plt.legend(title='Transfer Size')
  plt.grid(True)
  plt.tight_layout()
  plt.show()
Sanity check the transfer size distribution. Generally we want fewer larger transfers, not a high number of small ones.
# @title Distribution of transfer sizes
for df in dfs:
  # Number of transfers observed per distinct payload size.
  count_by_bytes = df.groupby('bytes').size()
  # Hoisted out of the row loop: the total is invariant.
  total = count_by_bytes.sum()
  data = {
      'Buffer size': [bytes_to_human(size) for size in count_by_bytes.index],
      'Count': list(count_by_bytes.values),
      'Percentage': [
          f'{(count / total) * 100:.2f}' for count in count_by_bytes.values
      ],
  }
  display.display(pd.DataFrame(data))
This indicates how many pending collectives there are at a given point in time throughout the profiling time window. If this chart is spiky or remains consistently high then the program may not be well optimized for compute/communication overlap.
# @title Inflight transfers by size.
agg_window_ms = 100

for i, df in enumerate(dfs):
  legend_labels = []
  per_size_series = []
  for size, group_df in df.groupby('bytes'):
    # Model each transfer as a +1 delta at its start and a -1 delta at its
    # end; the running sum of these deltas is the number of transfers of this
    # size that are inflight at any moment.
    starts = pd.Series(1, index=pd.to_datetime(group_df['start_ns'], unit='ns'))
    ends = pd.Series(-1, index=pd.to_datetime(group_df['end_ns'], unit='ns'))
    deltas = pd.concat([starts, ends]).sort_index()
    # Collapse duplicate timestamps into a single net delta.
    deltas = deltas.groupby(deltas.index).sum().sort_index()
    inflight = deltas.cumsum()
    # Downsample to the aggregation window, keeping the peak inflight count
    # within each window (empty windows become 0).
    per_window = inflight.resample(f'{agg_window_ms}ms').max().fillna(0)
    per_size_series.append(per_window)
    legend_labels.append(f'{bytes_to_human(int(size))}')
  # Align every size's series on one shared time index.
  combined_df = pd.concat(per_size_series, axis=1).fillna(0)
  # Stacked area chart; stackplot wants one row per series, hence the
  # transpose of the column-per-series DataFrame values.
  plt.figure(figsize=(12, 5))
  plt.stackplot(
      combined_df.index, combined_df.values.T, labels=legend_labels
  )
  plt.title(f'Inflight Transfers By Size For {labels[i]}')
  plt.xlabel('Time')
  plt.ylabel('Inflight transfers')
  plt.legend(title='Transfer Size (bytes)')
  plt.grid(True)
  plt.tight_layout()
  plt.show()
Optionally, you may enter the per-task maximum bandwidth for the platform to see it in the graph. For example, if you're using TPU v6e with two tasks per machine then the per-task bandwidth is 4 NICs per machine * 200 Gbps per NIC / 2 tasks per machine = 400 Gbps.
If the throughput is consistently near the platform's theoretical max then this indicates that this workload is network-bound.
# @title Network transfer throughput over time.
max_bandwidth_gbps = 100 # @param {'type':'number', isTemplate: true}
agg_window_ms = 1 # @param {'type':'number', isTemplate: true}

for df in dfs:
  # Average per-transfer bandwidth in Gbps: bits / seconds / 1e9.
  df['average_bandwidth_gbps'] = (
      (df['bytes'] * 8) / (df['latency_us'] * 1e-6) / 1e9
  )
  # Model aggregate bandwidth as a sum of step functions: each transfer adds
  # its average bandwidth at its start and removes it at its end.
  bandwidth_changes = pd.concat([
      pd.Series(
          df['average_bandwidth_gbps'].values,
          index=pd.to_datetime(df['start_ns'], unit='ns'),
      ),
      pd.Series(
          -df['average_bandwidth_gbps'].values,
          index=pd.to_datetime(df['end_ns'], unit='ns'),
      ),
  ]).sort_index()
  # Handle duplicate timestamps by summing their deltas.
  bandwidth_changes = (
      bandwidth_changes.groupby(bandwidth_changes.index).sum().sort_index()
  )
  # Running sum of the step changes gives instantaneous aggregate bandwidth.
  cumulative_bandwidth = bandwidth_changes.cumsum()
  # To get a correct time-weighted average, first create a dense time series
  # by upsampling (100x finer than the aggregation window) and forward-
  # filling, then downsample and take the mean.
  dense_bandwidth = cumulative_bandwidth.resample(
      f'{agg_window_ms/100}ms'
  ).ffill()
  average_bandwidth_per_window = (
      dense_bandwidth.resample(f'{agg_window_ms}ms').mean().fillna(0)
  )
  # Plot the results.
  plt.figure(figsize=(12, 5))
  plt.plot(
      average_bandwidth_per_window.index, average_bandwidth_per_window.values
  )
  # Optional reference line for the platform's theoretical maximum.
  if max_bandwidth_gbps > 0:
    plt.axhline(
        y=max_bandwidth_gbps,
        color='r',
        linestyle='--',
        label=f'Platform Max Bandwidth ({max_bandwidth_gbps} Gbps)',
    )
    # Fix: the axhline label was never shown because no legend was drawn.
    plt.legend()
  plt.title(f'Network Throughput (Gbps) for {df.attrs.get("label")}')
  plt.xlabel('Time')
  plt.ylabel('Throughput (Gbps)')
  plt.grid(True)
  plt.tight_layout()
  plt.show()
Returns the source and destination global device ID ({slice}-{per_slice_device_id}) pairs for transfers that occur after min_start_time_ns and that take longer than min_latency_us.
This data can help identify bad network cards, overloaded network switches, sub-optimal sharding, etc.
# @title Long tail host pairs
min_start_time_ns = 0 # @param
min_latency_us = 1000 # @param

for i, df in enumerate(dfs):
  # Keep transfers that started after the cutoff AND exceeded the latency
  # threshold, then count occurrences of each (src, dst) device-id pair.
  # Fix: the original chained `.replace({'src': i, 'dst': i})`, which replaces
  # cell *values* literally equal to the strings 'src'/'dst' with the profile
  # index — a no-op for the '<slice>-<device>' ids produced upstream and a
  # latent corruption bug; removed. The two `.loc` filters are also merged
  # into one boolean mask, which avoids alignment issues with duplicate
  # timestamp indices.
  long_tail = (
      df.loc[
          (df['start_ns'] > min_start_time_ns)
          & (df['latency_us'] > min_latency_us)
      ]
      .groupby(['src', 'dst'])
      .size()
      .reset_index(name='counts')
  )
  print(f'Long tail host src-dest pair for {labels[i]}')
  display.display(long_tail)