architecture/03-prediction-api.md
The Prediction API is the HTTP interface for running model inference. It uses a fixed envelope format that wraps model-specific inputs and outputs, allowing a uniform API across all Cog models.
| Endpoint | Role | Purpose |
|---|---|---|
| `POST /predictions` | Create | Start a new prediction |
| `PUT /predictions/{id}` | Create (idempotent) | Start or retrieve an existing prediction |
| `POST /predictions/{id}/cancel` | Cancel | Cancel a running prediction |
| `GET /health-check` | Health | Check server status |
| `GET /` | Index | List available endpoints |
| `GET /openapi.json` | Schema | OpenAPI specification |
By default, `POST /predictions` blocks until completion. For long-running predictions, use async mode with the `Prefer: respond-async` header -- the response returns immediately with a non-terminal status, and progress updates are delivered via webhook.
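For example, a client might start an async prediction like this (a sketch using the `requests` library; the server URL and webhook address are placeholders):

```python
import requests

# Start a prediction asynchronously; the Prefer header makes the server
# return immediately instead of blocking until the prediction finishes.
resp = requests.post(
    "http://localhost:5000/predictions",  # placeholder: wherever the Cog server runs
    json={
        "input": {"prompt": "A photo of a cat", "steps": 50},
        "webhook": "https://example.com/webhook",
        "webhook_events_filter": ["completed"],
    },
    headers={"Prefer": "respond-async"},
)

print(resp.status_code)       # 202
print(resp.json()["status"])  # not yet terminal; the final result arrives via webhook
```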
Every Cog model exposes the same endpoints with the same request/response structure. The model-specific parts (input fields, output type) are defined by the Schema and validated at runtime.
```mermaid
flowchart TB
    subgraph envelope ["Fixed Envelope (same for all models)"]
        direction TB
        fixed["id, status, created_at, logs, metrics, ..."]
        input["input#colon; { ... } — model-specific (from schema)"]
        output["output#colon; ... — model-specific (from schema)"]
    end
```
This pattern means clients can talk to every Cog model with the same request and response handling; only the contents of the `input` and `output` fields differ per model.
What clients send to start a prediction:
```json
{
  "id": "abc-123",
  "input": {
    "prompt": "A photo of a cat",
    "steps": 50
  },
  "webhook": "https://example.com/webhook",
  "webhook_events_filter": ["start", "output", "logs", "completed"]
}
```
| Field | Type | Purpose |
|---|---|---|
| `id` | string (optional) | Client-provided ID for idempotency |
| `input` | object | Model-specific -- validated against the schema |
| `webhook` | URL (optional) | Where to send progress updates |
| `webhook_events_filter` | array (optional) | Which events to send |
| `created_at` | datetime (optional) | Client-provided timestamp |
The `input` object is validated against the `Input` schema generated from the predictor's `predict()` signature. Unknown fields are rejected; missing required fields raise validation errors.
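For reference, a predictor whose `predict()` signature would yield the `prompt`/`steps` schema used in the examples here might look like the following sketch (field names, defaults, and constraints are illustrative):

```python
from cog import BasePredictor, Input


class Predictor(BasePredictor):
    def setup(self):
        # Load model weights once per worker; omitted here.
        pass

    def predict(
        self,
        prompt: str = Input(description="Text prompt"),
        steps: int = Input(description="Number of steps", default=50, ge=1, le=500),
    ) -> str:
        # The signature above is what the Input schema is generated from.
        ...
```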
What comes back from the API:
```json
{
  "id": "abc-123",
  "status": "succeeded",
  "input": {
    "prompt": "A photo of a cat",
    "steps": 50
  },
  "output": "https://storage.example.com/output.png",
  "logs": "Loading model...\nGenerating image...\nDone.",
  "error": null,
  "metrics": {
    "predict_time": 4.52
  },
  "created_at": "2024-01-15T10:30:00Z",
  "started_at": "2024-01-15T10:30:01Z",
  "completed_at": "2024-01-15T10:30:05Z"
}
```
| Field | Type | Purpose |
|---|---|---|
| `id` | string | Prediction identifier |
| `status` | enum | `starting`, `processing`, `succeeded`, `canceled`, `failed` |
| `input` | object | Echo of the input (for reference) |
| `output` | any | Model-specific -- type defined by the schema |
| `logs` | string | Captured stdout/stderr from `predict()` |
| `error` | string | Error message if status is `failed` |
| `metrics` | object | Timing and other metrics |
| `created_at` | datetime | When the request was received |
| `started_at` | datetime | When the prediction began |
| `completed_at` | datetime | When the prediction finished |
```mermaid
stateDiagram-v2
    [*] --> starting: Request received
    starting --> processing: predict() called
    processing --> succeeded: predict() returns
    processing --> failed: predict() raises exception
    processing --> canceled: Cancel requested
    succeeded --> [*]
    failed --> [*]
    canceled --> [*]
```
State transitions on the Prediction struct fire webhooks as a side effect -- calling `set_processing()` sends the start webhook, and `set_succeeded()` sends the terminal completed webhook.
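A rough, language-agnostic sketch of that pattern (the method names mirror the prose above, not the actual Rust types in the runtime):

```python
class Prediction:
    """Sketch: each state transition also pushes a webhook event."""

    def __init__(self, id, send_webhook):
        self.id = id
        self.status = "starting"
        self.output = None
        self._send = send_webhook  # callable(event_name, prediction)

    def set_processing(self):
        self.status = "processing"
        self._send("start", self)      # side effect: fire the start webhook

    def set_succeeded(self, output):
        self.status = "succeeded"
        self.output = output
        self._send("completed", self)  # side effect: fire the terminal completed webhook
```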
Synchronous predictions automatically cancel when the client connection drops. This prevents wasted computation on predictions where the client is no longer listening.
```rust
// SyncPredictionGuard is RAII -- it is dropped when the connection closes.
let guard = handle.sync_guard();
let result = service.predict(slot, input).await;
// If the connection drops here, the guard's Drop impl cancels the prediction.
```
Async predictions (via `Prefer: respond-async`) are unaffected -- they continue running regardless of client connection state, delivering results via webhook.
The `/health-check` endpoint always returns HTTP 200 with the status in the JSON body. This allows load balancers and orchestrators to distinguish between "server is running but not ready" and "server is down."
| State | JSON status | Condition |
|---|---|---|
| UNKNOWN | `"UNKNOWN"` | Process just started, not yet serving |
| STARTING | `"STARTING"` | Worker subprocess initializing, running `setup()` |
| READY | `"READY"` | Worker ready, slots available |
| BUSY | `"BUSY"` | All slots occupied (backpressure) |
| SETUP_FAILED | `"SETUP_FAILED"` | `setup()` threw an exception |
| DEFUNCT | `"DEFUNCT"` | Fatal error, worker crashed |
| UNHEALTHY | `"UNHEALTHY"` | User-defined healthcheck failed (transient) |
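A readiness probe can poll the endpoint until the body reports READY. A sketch, assuming the JSON body carries a top-level `status` field as in the table above and that the server runs on a placeholder local URL:

```python
import time
import requests

def wait_until_ready(base_url="http://localhost:5000", timeout=300):
    """Poll /health-check until the reported status is READY (or a terminal failure)."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        status = requests.get(f"{base_url}/health-check").json()["status"]
        if status == "READY":
            return
        if status in ("SETUP_FAILED", "DEFUNCT"):
            raise RuntimeError(f"server is unusable: {status}")
        time.sleep(1)
    raise TimeoutError("server did not become READY in time")
```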
When all concurrency slots are occupied, new predictions receive 409 Conflict instead of queuing. Clients should implement retry with backoff.
Prediction endpoints return 503 when health is not READY.
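A client-side retry loop for 409 (all slots busy) and 503 (not READY) might look like this sketch (delays and attempt limits are arbitrary):

```python
import time
import requests

def create_prediction(url, payload, max_attempts=6):
    delay = 1.0
    for _ in range(max_attempts):
        resp = requests.post(url, json=payload)
        if resp.status_code not in (409, 503):  # 409 = all slots busy, 503 = not READY
            resp.raise_for_status()
            return resp.json()
        time.sleep(delay)
        delay *= 2  # exponential backoff
    raise RuntimeError("server stayed busy or unready after retries")
```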
`PUT /predictions/{id}` is idempotent -- if the prediction already exists, it returns the current state; if not, it creates a new one. This is backed by a concurrent DashMap, so it is thread-safe without a global lock and safe under concurrent requests with the same ID.
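From the client's perspective, this means the same `PUT` can be retried after a network failure without creating a duplicate prediction. A sketch (the ID scheme and server URL are placeholders):

```python
import uuid
import requests

prediction_id = str(uuid.uuid4())  # client-chosen ID; reusing it makes retries safe

resp = requests.put(
    f"http://localhost:5000/predictions/{prediction_id}",
    json={"input": {"prompt": "A photo of a cat", "steps": 50}},
)
# Retrying this request with the same ID returns the existing prediction's
# current state instead of starting a second prediction.
print(resp.json()["status"])
```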
The runtime uses explicit permit tokens for concurrency control:
```rust
// Acquire a permit (waits if all slots are busy)
let permit = permit_pool.acquire().await?;

// The permit is held for the duration of the prediction
let slot_id = permit.slot_id();
let result = orchestrator.predict(slot_id, input).await;

// The permit is automatically returned to the pool on drop
drop(permit);
```
The number of slots is configured in `cog.yaml`:
```yaml
# cog.yaml
concurrency:
  max: 5
```
This creates 5 slots in the `PermitPool`. Each slot corresponds to one Unix socket connection to the worker subprocess.
```mermaid
flowchart LR
    subgraph request["Incoming Request"]
        json["JSON body"]
    end
    subgraph validation["Validation"]
        schema["Schema (Input type)"]
        validate["Schema Validation"]
    end
    subgraph transform["Transformation"]
        download["Download URLs → Files"]
        coerce["Type Coercion"]
    end
    subgraph predict["predict()"]
        kwargs["**kwargs"]
    end
    json --> validate
    schema --> validate
    validate --> download
    download --> coerce
    coerce --> kwargs
```
The `input` object comes from the request body, `cog.Path` fields are fetched to local temp files, and the result is passed to `predict()` as `**kwargs`. The output side reverses the process:

```mermaid
flowchart LR
    subgraph predict["predict()"]
        result["Return value / yields"]
    end
    subgraph transform["Transformation"]
        upload["Upload files → URLs"]
        serialize["JSON serialization"]
    end
    subgraph response["Response"]
        output["output field"]
    end
    result --> upload
    upload --> serialize
    serialize --> output
```

`cog.Path` outputs are uploaded and replaced with URLs in the `output` field of the response.

Input files (`cog.Path`):
```
Client sends:     {"input": {"image": "https://example.com/photo.jpg"}}
Server downloads: /tmp/inputabc123.jpg
predict() sees:   image = Path("/tmp/inputabc123.jpg")
```
Output files (`cog.Path`):
```
predict() returns: Path("/tmp/output.png")
Server uploads:    https://storage.example.com/output-xyz.png
Client receives:   {"output": "https://storage.example.com/output-xyz.png"}
```
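Inside the predictor, only local paths are visible; the URL handling happens in the runtime. A minimal illustrative sketch (`process` is a hypothetical model call):

```python
from cog import BasePredictor, Input, Path


class Predictor(BasePredictor):
    def predict(self, image: Path = Input(description="Input image")) -> Path:
        # `image` is already a local temp file downloaded from the URL the client sent.
        out = Path("/tmp/output.png")
        process(image, out)  # hypothetical model call
        # Returning a Path makes the runtime upload the file and place its URL
        # in the response's `output` field.
        return out
```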
Cancellation uses IPC messages with different strategies for sync vs async predictors:
```mermaid
flowchart TD
    parent["Parent#colon; ControlRequest#colon;#colon;Cancel { slot }"]
    parent --> worker["Worker#colon; handler.cancel(slot)"]
    subgraph sync ["Sync predictors"]
        direction TB
        s1["Set CANCEL_REQUESTED flag for slot"]
        s2["Send SIGUSR1 to self"]
        s3["Signal handler#colon; raise KeyboardInterrupt\n(if in cancelable region)"]
    end
    subgraph async_p ["Async predictors"]
        direction TB
        a1["Get future from slot state"]
        a2["future.cancel() → Python raises\nasyncio.CancelledError"]
    end
    worker --> sync
    worker --> async_p
```
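An async predictor can observe cancellation as `asyncio.CancelledError` and clean up before re-raising. An illustrative sketch (`run_model` and `release_gpu_buffers` are hypothetical helpers):

```python
import asyncio
from cog import BasePredictor


class Predictor(BasePredictor):
    async def predict(self, prompt: str) -> str:
        try:
            return await run_model(prompt)      # hypothetical long-running call
        except asyncio.CancelledError:
            # Raised when the runtime cancels the future for this slot.
            await release_gpu_buffers()         # hypothetical cleanup
            raise                               # re-raise so the prediction is marked canceled
```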
For async predictions, progress is delivered via webhooks:
```mermaid
sequenceDiagram
    participant Client
    participant Cog
    participant Webhook
    Client->>Cog: POST /predictions (Prefer: respond-async)
    Cog-->>Client: 202 {status: "starting"}
    Cog->>Webhook: {status: "starting"}
    Note over Cog: predict() starts
    Cog->>Webhook: {status: "processing"}
    loop Output yields
        Cog->>Webhook: {output: "partial...", logs: "..."}
    end
    Cog->>Webhook: {status: "succeeded", output: "final"}
```
| Event | When | Payload Contains |
|---|---|---|
| `start` | Prediction begins | `status: starting` |
| `output` | Each yield from the iterator | Partial output |
| `logs` | Log lines captured | Updated logs |
| `completed` | Prediction finishes | Final status, output, metrics |
Filter events with `webhook_events_filter`:
```json
{
  "input": {...},
  "webhook": "https://...",
  "webhook_events_filter": ["completed"]
}
```
Webhook delivery includes structured retry with exponential backoff and automatic OpenTelemetry trace context propagation in headers.
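On the receiving side, any HTTP endpoint that accepts the prediction JSON works. A minimal sketch using FastAPI (the framework choice, route, and `store_result` helper are arbitrary/hypothetical):

```python
from fastapi import FastAPI, Request

app = FastAPI()

@app.post("/webhook")
async def handle_webhook(request: Request):
    prediction = await request.json()
    if prediction["status"] in ("succeeded", "failed", "canceled"):
        # Terminal update: output, error, and metrics are final here.
        store_result(prediction)  # hypothetical persistence call
    return {"ok": True}
```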
For models that yield output progressively:
```python
def predict(self, prompt: str) -> Iterator[str]:
    for token in generate(prompt):
        yield token
```
The API can deliver these as:
- Incremental `output` webhooks (one event per yield)
- A server-sent event stream, requested with `Accept: text/event-stream`
- An accumulated list in the final response, e.g. `output: ["a", "b", "c"]`

Relevant source files:

- `crates/coglet/src/transport/http/` -- HTTP route handlers, request parsing, response construction
- `crates/coglet/src/prediction.rs` -- the Prediction state machine (status transitions, webhook firing)
- `crates/coglet/src/webhook.rs` -- webhook delivery, retry logic, trace context propagation
- `crates/coglet/src/bridge/` -- IPC protocol definitions and transport
- `crates/coglet/src/permit/` -- slot-based concurrency control