docs/contributing/contributing-ai-functions.md
This page contains references and guides for developing new AI Functions in daft. These steps will guide you through implementing a model expression like:
- embed_text
- embed_image
- classify_text
- prompt

All model expressions are backed by a Protocol and a Descriptor. These protocols are based on the verb and modality, and the descriptor is used to instantiate the model at runtime.
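The split between a protocol and its descriptor can be sketched without importing Daft. In this minimal example, `KeywordClassifier` and `KeywordClassifierDescriptor` are hypothetical names invented for illustration, and the descriptor is a plain dataclass standing in for Daft's `Descriptor` base class; only the `instantiate()` method used later on this page is assumed.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@runtime_checkable
class TextClassifier(Protocol):
    """Same shape as the protocol on this page: verb (classify) + modality (text)."""

    def classify_text(self, text: list[str], labels: list[str]) -> list[str]: ...


class KeywordClassifier:
    """Toy implementation: picks the first label that occurs in the text."""

    def classify_text(self, text: list[str], labels: list[str]) -> list[str]:
        results = []
        for item in text:
            lowered = item.lower()
            # Fall back to the first label when nothing matches.
            match = next((label for label in labels if label.lower() in lowered), labels[0])
            results.append(match)
        return results


@dataclass
class KeywordClassifierDescriptor:
    """Hypothetical stand-in for a descriptor: cheap, plain data describing a model."""

    model: str = "keyword-v0"

    def instantiate(self) -> TextClassifier:
        # The (potentially expensive) model object is only created here.
        return KeywordClassifier()


descriptor = KeywordClassifierDescriptor()
classifier = descriptor.instantiate()

# runtime_checkable lets isinstance() check that the protocol's methods are present.
is_classifier = isinstance(classifier, TextClassifier)
predictions = classifier.classify_text(["I love sports", "stock markets fell"], ["sports", "stock"])
```

The point of the design is that the descriptor stays small and carries only metadata, while the object that does the work is built from it on demand.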
# daft.ai.protocols
from typing import Protocol, runtime_checkable

@runtime_checkable
class TextClassifier(Protocol):
    """Protocol for text classification implementations."""

    def classify_text(self, text: list[str], labels: list[str]) -> list[str]:
        """Classifies a batch of text strings using the given label(s)."""
        ...

class TextClassifierDescriptor(Descriptor[TextClassifier]):
    """Descriptor for a TextClassifier implementation."""
You must update the Provider interface with a new method that creates your descriptor. Give it a default implementation that simply raises, so that existing providers do not all need to be updated.
# daft.ai.provider
class Provider(ABC):
    # ... existing code

    def get_text_classifier(self, model: str | None = None, **options: Any) -> TextClassifierDescriptor:
        """Returns a TextClassifierDescriptor for this provider."""
        raise not_implemented_err(self, method="classify_text")
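The effect of the raising default can be illustrated with a small, self-contained sketch. Everything here is a stand-in: `not_implemented_err` is reimplemented as a hypothetical helper (its real signature and message in Daft may differ), `LegacyProvider` and `NewProvider` are invented names, and a string is returned in place of a real descriptor to keep the example short.

```python
from __future__ import annotations

from abc import ABC
from typing import Any


def not_implemented_err(provider: Provider, method: str) -> NotImplementedError:
    # Hypothetical stand-in for the helper used above.
    return NotImplementedError(f"{method} is not supported by the '{provider.name}' provider")


class Provider(ABC):
    name = "base"

    def get_text_classifier(self, model: str | None = None, **options: Any) -> str:
        # Default: providers that do not override this method raise when it is called.
        raise not_implemented_err(self, method="classify_text")


class LegacyProvider(Provider):
    """An existing provider written before the new method was added; left untouched."""

    name = "legacy"


class NewProvider(Provider):
    name = "new"

    def get_text_classifier(self, model: str | None = None, **options: Any) -> str:
        # A real provider would return a TextClassifierDescriptor here.
        return f"descriptor({model or 'default-model'})"


def describe(provider: Provider) -> str:
    try:
        return provider.get_text_classifier()
    except NotImplementedError as err:
        return f"error: {err}"


legacy_result = describe(LegacyProvider())
new_result = describe(NewProvider())
```

Existing providers keep working unchanged and only fail, with a descriptive error, if a user asks them for the new capability.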
In daft.functions.ai, add the function, then re-export it from daft/functions/__init__.py.
The implementation resolves the provider from the given arguments, then calls the
appropriate provider method to get the relevant descriptor.
import daft
from daft import DataType, Expression, Series
from daft.ai.protocols import TextClassifier, TextClassifierDescriptor
from daft.ai.provider import Provider

# _resolve_provider is assumed to be available in this module (daft.functions.ai).

def classify_text(
    text: Expression,
    labels: list[str],
    *,
    provider: str | Provider | None = None,
    model: str | None = None,
) -> Expression:
    # Load a TextClassifierDescriptor from the resolved provider
    text_classifier = _resolve_provider(provider, "transformers").get_text_classifier(model)
    # Create the stateful class UDF
    classifier = _TextClassifierExpression(text_classifier, labels)
    # Return the expression
    return classifier.classify(text)

@daft.cls
class _TextClassifierExpression:
    """Function expression implementation for a TextClassifier protocol."""

    def __init__(self, descriptor: TextClassifierDescriptor, labels: list[str]):
        # Instantiate from the descriptor in __init__
        self.text_classifier = descriptor.instantiate()
        self.labels = labels

    @daft.method.batch(return_dtype=DataType.string())
    def classify(self, text: Series) -> list[str]:
        text_list = text.to_pylist()
        # Skip the model call entirely for an empty batch
        if not text_list:
            return []
        return self.text_classifier.classify_text(text_list, self.labels)
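How `_resolve_provider` works internally is not shown on this page. The following is only a sketch of one plausible shape for such a helper, assuming it accepts a name, an instance, or `None` and falls back to a default name. The `Provider` class, the `_PROVIDERS` registry, and `resolve_provider` are all hypothetical and are not Daft's implementation.

```python
from __future__ import annotations


class Provider:
    """Minimal stand-in for a provider; real providers expose descriptor factories."""

    def __init__(self, name: str) -> None:
        self.name = name


# Hypothetical registry of known provider names.
_PROVIDERS = {
    "transformers": Provider("transformers"),
    "openai": Provider("openai"),
}


def resolve_provider(provider: str | Provider | None, default: str) -> Provider:
    """Turn a name, an instance, or None into a Provider instance."""
    if provider is None:
        provider = default  # nothing given: fall back to the function's default
    if isinstance(provider, Provider):
        return provider  # already an instance: use it as-is
    try:
        return _PROVIDERS[provider]
    except KeyError:
        raise ValueError(f"Unknown provider: {provider!r}") from None


custom = Provider("my-provider")
resolved_default = resolve_provider(None, "transformers")
resolved_named = resolve_provider("openai", "transformers")
resolved_instance = resolve_provider(custom, "transformers")
```

Accepting all three forms keeps the user-facing function signature small while still allowing a fully configured provider object to be passed in.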
Here is a simplified example implementation of embed_text for OpenAI. It should give you an idea of where your actual logic should live; the previous steps exist to hook your new expression into the provider/model system.
from dataclasses import dataclass

import numpy as np
from openai import OpenAI

from daft.ai.protocols import TextEmbedder, TextEmbedderDescriptor

@dataclass
class OpenAITextEmbedderDescriptor(TextEmbedderDescriptor):
    model: str  # store some metadata

    # We can use the stored metadata to instantiate the protocol implementation
    def instantiate(self) -> TextEmbedder:
        return OpenAITextEmbedder(client=OpenAI(), model=self.model)

@dataclass
class OpenAITextEmbedder(TextEmbedder):
    client: OpenAI
    model: str

    # This is a simple version that sends the whole batch in one request. The full
    # implementation uses dynamic batching and has error handling mechanisms.
    def embed_text(self, text: list[str]) -> list[Embedding]:
        response = self.client.embeddings.create(
            input=text,
            model=self.model,
            encoding_format="float",
        )
        return [np.array(embedding.embedding) for embedding in response.data]
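The comment above mentions that the full implementation batches requests dynamically. As a much simpler illustration of the general idea, the sketch below splits the input into fixed-size chunks and checks that each response has one vector per input. This is fixed-size batching, not Daft's dynamic batching, and `chunked`, `embed_in_batches`, `fake_embed_batch`, and `max_batch_size` are hypothetical names; the fake function stands in for a real API call.

```python
from __future__ import annotations

from typing import Callable, Iterator


def chunked(items: list[str], size: int) -> Iterator[list[str]]:
    """Yield consecutive slices of at most `size` items."""
    if size <= 0:
        raise ValueError("size must be positive")
    for start in range(0, len(items), size):
        yield items[start : start + size]


def embed_in_batches(
    text: list[str],
    embed_batch: Callable[[list[str]], list[list[float]]],
    max_batch_size: int = 2,
) -> list[list[float]]:
    """Call `embed_batch` once per chunk and concatenate the results in order."""
    embeddings: list[list[float]] = []
    for batch in chunked(text, max_batch_size):
        result = embed_batch(batch)
        # Basic sanity check: exactly one embedding per input string.
        if len(result) != len(batch):
            raise RuntimeError("embedding count does not match input count")
        embeddings.extend(result)
    return embeddings


def fake_embed_batch(batch: list[str]) -> list[list[float]]:
    # Stand-in for a request to an embeddings endpoint: one vector per input.
    return [[float(len(item))] for item in batch]


vectors = embed_in_batches(["a", "bb", "ccc"], fake_embed_batch, max_batch_size=2)
```

Passing the request function as a parameter keeps the batching logic testable without network access.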
You can now use this like any other expression.
import daft
from daft.functions import embed_text

df = daft.read_parquet("/path/to/file.parquet")  # assuming it has a column named 'text'
df = df.with_column("embedding", embed_text(df["text"], provider="openai"))  # <- set provider to 'openai'
df.show()