Back to Gradio

Load

scripts/load_test/load.ipynb

3.41.0 · 10.4 KB
Original Source
python
import json
import time
import uuid

import requests

!pip install -q websocket-client
import websocket
python
# Host of the server under test, without a scheme: the request functions in
# this notebook prepend "http://" or "ws://" themselves.
URL = "18.236.68.146"
python
from concurrent.futures import ThreadPoolExecutor


def run_in_parallel(func, n):
    """Invoke ``func`` ``n`` times on a thread pool and gather the results.

    The pool is sized to ``n`` workers, so every invocation gets its own
    thread.

    Args:
        func: Zero-argument callable to invoke.
        n: Number of invocations; must be an ``int`` that is at least 1.

    Returns:
        A list holding the ``n`` return values, in submission order.

    Raises:
        ValueError: If ``func`` is not callable or ``n`` is not a positive
            ``int``.
    """
    arguments_ok = callable(func) and isinstance(n, int) and n >= 1
    if not arguments_ok:
        raise ValueError("Invalid function or number of repetitions")

    with ThreadPoolExecutor(max_workers=n) as pool:
        pending = [pool.submit(func) for _ in range(n)]

        # result() blocks until the call has finished and re-raises any
        # exception the call raised.
        collected = [job.result() for job in pending]

    return collected

Gradio 4

python
def request():
    """Push one prediction through the Gradio 4 queue endpoints and time it.

    Joins the queue with a POST to ``queue/join`` and then reads the
    server-sent-event stream from ``queue/data`` until an event whose payload
    contains ``close_stream`` arrives.

    Returns:
        A ``(duration, message_count, data)`` tuple: elapsed wall-clock
        seconds, the number of ``data:`` events seen before ``close_stream``,
        and the ``["output"]["data"]`` field of the last of those events.

    Raises:
        AssertionError: If the join request does not return HTTP 200.
        requests.HTTPError: If the stream request returns an error status.
    """
    start_time = time.time()
    session_hash = uuid.uuid4().hex
    payload = {"data": ["test"], "fn_index": 0, "session_hash": session_hash}
    url = f"http://{URL}/"
    resp = requests.post(f"{url}queue/join", json=payload, timeout=5)
    assert resp.status_code == 200

    field_prefix = "data:"
    message_count = 0
    output = ""
    with requests.get(
        f"{url}queue/data?session_hash={session_hash}", stream=True
    ) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            if not line:
                continue
            decoded_line = line.decode("utf-8")
            if not decoded_line.startswith(field_prefix):
                continue
            # Strip only the leading SSE field name (plus the single optional
            # space after the colon). A blanket str.replace would also delete
            # any "data: " text that occurs inside the payload itself.
            data = decoded_line[len(field_prefix):]
            if data.startswith(" "):
                data = data[1:]
            if "close_stream" in data:
                break
            output = data
            message_count += 1

    end_time = time.time()
    duration = end_time - start_time
    return (duration, message_count, json.loads(output)["output"]["data"])
python
# One standalone call to request() ahead of the parallel runs below.
request()
python
# Run request() 5 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 5)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)
python
# Run request() 25 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 25)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)
python
# Run request() 100 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 100)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)
python
# Run request() 250 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 250)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)

Gradio 3.x

python
def request():
    """Push one prediction through the Gradio 3.x websocket queue and time it.

    Opens a websocket to ``queue/join`` and answers the server's control
    messages: ``send_hash`` is answered with the session hash and function
    index, ``send_data`` with the input payload. The loop ends when a
    ``process_completed`` message arrives.

    Returns:
        A ``(duration, message_count, output)`` tuple: elapsed wall-clock
        seconds, the number of messages received, and the
        ``["output"]["data"]`` field of the ``process_completed`` message.
    """
    url = f"ws://{URL}/"
    session_hash = uuid.uuid4().hex
    output = None
    message_count = 0
    start_time = time.time()

    # Connect before entering the try block: if the connection attempt itself
    # fails there is nothing to close, and a `finally: ws.close()` would raise
    # UnboundLocalError and hide the real connection error.
    ws = websocket.create_connection(f"{url}queue/join")
    try:
        while True:
            message = ws.recv()  # Wait and receive incoming message
            message_count += 1
            message = json.loads(message)
            msg = message["msg"]

            if msg == "send_hash":
                ws.send(json.dumps({"session_hash": session_hash, "fn_index": 0}))
            elif msg == "send_data":
                # The input payload is the reply to "send_data"; sending it
                # in response to "send_hash" only works if the server happens
                # to buffer it until it asks for the data.
                ws.send(
                    json.dumps(
                        {
                            "data": ["test"],
                            "event_data": None,
                            "fn_index": 0,
                            "session_hash": session_hash,
                        }
                    )
                )
            elif msg == "process_completed":
                output = message["output"]["data"]
                break
    finally:
        ws.close()  # Ensure the connection is closed properly

    duration = time.time() - start_time
    return duration, message_count, output
python
# One standalone call to request() ahead of the parallel runs below.
request()
python
# Run request() 5 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 5)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)
python
# Run request() 25 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 25)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)
python
# Run request() 100 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 100)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)
python
# Run request() 250 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 250)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)

Simple SSE

python
def request():
    """Read 500 events from the plain ``sse`` endpoint and time it.

    Streams ``GET sse`` and counts lines that start with ``data:``. Reading
    stops once 500 such events have been seen, or earlier if the stream ends.

    Returns:
        A ``(duration, message_count, output)`` tuple: elapsed wall-clock
        seconds, the number of ``data:`` events counted, and the payload of
        the last one (``""`` if none was received).

    Raises:
        requests.HTTPError: If the stream request returns an error status.
    """
    start_time = time.time()
    url = f"http://{URL}/"
    field_prefix = "data:"
    message_count = 0
    output = ""
    with requests.get(f"{url}sse", stream=True) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            if line:
                decoded_line = line.decode("utf-8")
                if decoded_line.startswith(field_prefix):
                    # Strip only the leading SSE field name (plus the single
                    # optional space after the colon). A blanket str.replace
                    # would also delete "data: " text inside the payload.
                    output = decoded_line[len(field_prefix):]
                    if output.startswith(" "):
                        output = output[1:]
                    message_count += 1
            if message_count == 500:
                break

    end_time = time.time()
    duration = end_time - start_time
    return (duration, message_count, output)
python
# One standalone call to request() ahead of the parallel runs below.
request()
python
# Run request() 5 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 5)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)
python
# Run request() 25 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 25)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)
python
# Run request() 100 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 100)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)
python
# Run request() 250 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 250)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)

Simple Websocket

python
def request():
    """Receive 500 messages from the plain ``ws`` endpoint and time it.

    Opens a websocket to ``ws`` and reads until 500 messages have arrived.

    Returns:
        A ``(duration, message_count, output)`` tuple: elapsed wall-clock
        seconds, the number of messages received, and the last message.
    """
    url = f"ws://{URL}/"
    output = None
    message_count = 0
    start_time = time.time()

    # Connect before entering the try block: if the connection attempt itself
    # fails there is nothing to close, and a `finally: ws.close()` would raise
    # UnboundLocalError and hide the real connection error.
    ws = websocket.create_connection(f"{url}ws")
    try:
        while True:
            message = ws.recv()  # Wait and receive incoming message
            message_count += 1
            output = message
            if message_count == 500:
                break
    finally:
        ws.close()  # Ensure the connection is closed properly

    duration = time.time() - start_time
    return duration, message_count, output
python
# One standalone call to request() ahead of the parallel runs below.
request()
python
# Run request() 5 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 5)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)
python
# Run request() 25 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 25)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)
python
# Run request() 100 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 100)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)
python
# Run request() 250 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 250)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)

SSE w/ Workers

python
def request():
    """Submit a job to ``sse/send`` and time its ``sse/listen`` event stream.

    POSTs the payload to ``sse/send``, takes the ``session_id`` from the JSON
    reply, and then reads the server-sent-event stream at ``sse/listen`` for
    that session until the stream ends.

    Returns:
        A ``(duration, message_count, output)`` tuple: elapsed wall-clock
        seconds, the number of ``data:`` events counted, and the payload of
        the last one (``""`` if none was received).

    Raises:
        AssertionError: If the send request does not return HTTP 200.
        requests.HTTPError: If the listen request returns an error status.
    """
    start_time = time.time()
    payload = {"data": "test"}
    url = f"http://{URL}/"
    resp = requests.post(f"{url}sse/send", json=payload, timeout=5)
    assert resp.status_code == 200
    session_id = resp.json()["session_id"]

    field_prefix = "data:"
    message_count = 0
    output = ""
    with requests.get(
        f"{url}sse/listen?session_id={session_id}", stream=True
    ) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            if not line:
                continue
            decoded_line = line.decode("utf-8")
            if decoded_line.startswith(field_prefix):
                # Strip only the leading SSE field name (plus the single
                # optional space after the colon). A blanket str.replace
                # would also delete "data: " text inside the payload.
                output = decoded_line[len(field_prefix):]
                if output.startswith(" "):
                    output = output[1:]
                message_count += 1

    end_time = time.time()
    duration = end_time - start_time
    return (duration, message_count, output)
python
# One standalone call to request() ahead of the parallel runs below.
request()
python
# Run request() 5 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 5)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)
python
# Run request() 25 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 25)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)
python
# Run request() 100 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 100)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)
python
# Run request() 250 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 250)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)

Websockets w/ Workers

python
def request():
    """Send ``"test"`` to the ``ws`` endpoint and time 500 replies.

    Opens a websocket to ``ws``, sends the string ``"test"``, and reads until
    500 messages have arrived.

    Returns:
        A ``(duration, message_count, output)`` tuple: elapsed wall-clock
        seconds, the number of messages received, and the last message.
    """
    url = f"ws://{URL}/"
    output = None
    message_count = 0
    start_time = time.time()

    # Connect before entering the try block: if the connection attempt itself
    # fails there is nothing to close, and a `finally: ws.close()` would raise
    # UnboundLocalError and hide the real connection error.
    ws = websocket.create_connection(f"{url}ws")
    try:
        ws.send("test")

        while True:
            message = ws.recv()  # Wait and receive incoming message
            message_count += 1
            output = message
            if message_count == 500:
                break
    finally:
        ws.close()  # Ensure the connection is closed properly

    duration = time.time() - start_time
    return duration, message_count, output
python
# One standalone call to request() ahead of the parallel runs below.
request()
python
# Run request() 5 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 5)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)
python
# Run request() 25 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 25)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)
python
# Run request() 100 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 100)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)
python
# Run request() 250 times in parallel, then print the mean duration and the
# mean message count across those runs.
output = run_in_parallel(request, 250)
avg_duration = sum(elapsed for elapsed, *_ in output) / len(output)
avg_msg = sum(count for _, count, *_ in output) / len(output)
print(avg_duration, avg_msg)