ADR 0023: UDFs
Status: Implemented
Proposed by: Adam Gibson (31 Jan 2023)
Discussed with: Paul Dubs
Finalized by: Adam Gibson (2 Feb 2023)
Users should be able to define their own custom operations in SameDiff, including custom gradients. Currently, User-Defined Functions (UDFs) are not properly integrated into SameDiff, and defining one requires handling multiple aspects of the system, such as:
To support custom UDFs in SameDiff, the following components will be created:
These components work together to allow for the following:
```java
SameDiff sd = SameDiff.create();
UserDefinedCustomOp userDefinedCustomOp = ...;
SDVariable[] opOutputs = sd.doUdf(userDefinedCustomOp);
```
When an operation is registered, it is saved and loaded with the graph like any other operation. When a graph is loaded, operations are created dynamically via reflection, using annotation scanning to find the op classes.
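Beyond "annotation scanning" and "reflection", the ADR does not spell out the discovery mechanism. The following is a minimal plain-Java sketch under stated assumptions. `SketchUdf` is a hypothetical marker annotation standing in for `@UserDefinedOp`, and `SketchUdfRegistry` is a hypothetical registry. Annotated classes are registered by op name, and a loader later recreates an op through its no-arg constructor. None of these names are ND4J APIs, and a real implementation would scan the classpath instead of taking a candidate list.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical marker annotation standing in for @UserDefinedOp.
// RUNTIME retention is required so the annotation is visible via reflection.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
@interface SketchUdf {}

// Hypothetical minimal op interface: only the name used as the registry key.
interface SketchOp {
    String opName();
}

@SketchUdf
class SketchAddUdf implements SketchOp {
    public SketchAddUdf() {} // no-arg constructor, needed for reflective creation on load
    public String opName() { return "test_add_udf"; }
}

// Not annotated, so the registry must ignore it.
class SketchUnannotatedOp implements SketchOp {
    public String opName() { return "not_registered"; }
}

// Hypothetical registry mapping op names to annotated classes.
class SketchUdfRegistry {
    private final Map<String, Class<? extends SketchOp>> byName = new HashMap<>();

    // Registers every candidate class that carries the marker annotation.
    void scan(List<Class<? extends SketchOp>> candidates) {
        for (Class<? extends SketchOp> c : candidates) {
            if (c.isAnnotationPresent(SketchUdf.class)) {
                byName.put(instantiate(c).opName(), c);
            }
        }
    }

    // Recreates an op from its stored name, as a graph loader would.
    SketchOp create(String opName) {
        Class<? extends SketchOp> c = byName.get(opName);
        if (c == null) {
            throw new IllegalArgumentException("Unknown op: " + opName);
        }
        return instantiate(c);
    }

    private static SketchOp instantiate(Class<? extends SketchOp> c) {
        try {
            return c.getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("No usable no-arg constructor on " + c.getName(), e);
        }
    }
}
```

This is also why the example op below keeps an empty constructor: reflective creation has no `SameDiff` instance or arguments to pass in.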
Below is an example:
```java
@UserDefinedOp // Annotation used to discover custom ops to register
public class TestAddUdf extends UserDefinedCustomOp { // Base class to extend

    // No-arg constructor. Used when recreating the op from FlatBuffers in
    // {@link org.nd4j.autodiff.samediff.serde.FlatBuffersMapper}.
    public TestAddUdf() {
        super();
    }

    // Other constructors can take whatever the user wishes. Custom ops usually take a
    // SameDiff instance and one or more SDVariable arguments; these are the minimum
    // components needed to instantiate an op.
    // Each constructor calls super(...) so the op is properly configured for the SameDiff graph passed in.
    public TestAddUdf(SameDiff sameDiff, SDVariable arg) {
        super(sameDiff, arg);
    }

    public TestAddUdf(SameDiff sameDiff, SDVariable[] args) {
        super(sameDiff, args);
    }

    // Used to calculate output variables when registering an op with a graph.
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        // Users must implement this method. SameDiff uses it to determine the number of
        // output variables when that number can't be determined from getNumOutputs().
        return Arrays.asList(dataTypes.get(0));
    }

    @Override
    public void setPropertiesForFunction(Map<String, Object> properties) {
        // Users can define properties as fields. If they do, they must implement this method
        // and propertiesForFunction(). Both are used to recreate the op from scratch when a
        // model is saved and loaded.
    }

    @Override
    public Map<String, Object> propertiesForFunction() {
        // Returns properties (fields on the Java class) as a map. Properties can be any value
        // stored in a field on the op itself. They are optional and may not be needed,
        // depending on the op. All properties end up being passed to the underlying iArguments,
        // tArguments, and other associated data structures inherited from DynamicCustomOp.
        return Collections.emptyMap();
    }

    @Override
    public int getNumOutputs() {
        // Returns the number of outputs for the op. If an op has a variable number of outputs,
        // the user will need to call SDVariable.eval() to compute the number of outputs as an int.
        return 1;
    }

    @Override
    public String opName() {
        // The op name, required for proper registration with the registry.
        return "test_add_udf";
    }

    @Override
    public void configureFromArguments() {
        // A hook for configuring the op after creation from its specified arguments, such as
        // ints, floats/doubles, and input variables. These are the underlying arguments passed
        // to every C/C++ op: iArguments, tArguments, dArguments, inputArguments, and outputArguments.
    }

    @Override
    public void configureWithSameDiff(SameDiff sameDiff) {
        this.sameDiff = sameDiff;
        // Implement this method to handle initialization after the op is created. It initializes
        // values using the relevant SameDiff metadata, such as input and output argument metadata
        // from the SDVariables returned by args().
    }

    @Override
    public boolean isInplaceCall() {
        // Indicates whether the inputs are also the outputs.
        // Take extra care to avoid bugs when an operation is in-place,
        // particularly when an input to the operation is a view.
        return false;
    }

    @Override
    public List<LongShapeDescriptor> calculateOutputShape() {
        // Describes how to calculate the output shape from the inputs.
        // calculateOutputShape is called when output arrays are created dynamically
        // to store the result of an operation's execution.
        // It is not called when an operation is in-place.
        return Arrays.asList(inputArguments.get(0).shapeDescriptor());
    }

    @Override
    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
        // Same as above, except the inputs are obtained from the op context
        // instead of from the op itself.
        // It is not called when an operation is in-place.
        return Arrays.asList(oc.getInputArrays().get(0).shapeDescriptor());
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        // doDiff must be implemented if the operation is to be used for training.
        // It should return one gradient for each input.
        return new AddBpOp(sameDiff, larg(), rarg(), f1.get(0)).outputs();
    }

    @Override
    public void exec() {
        // Executes the operation and sets the op's outputs.
        AddOp addOp = new AddOp();
        addOp.addInputArgument(inputArguments.get(0), inputArguments.get(1));
        Nd4j.getExecutioner().exec(addOp);
        // Use the public accessor: the protected field of another op instance
        // is not accessible from a subclass in a different package.
        this.outputArguments.addAll(addOp.outputArguments());
    }

    @Override
    public void exec(OpContext opContext) {
        // Executes the operation and sets the outputs on the op context.
        Nd4j.getExecutioner().exec(new AddOp(), opContext);
    }
}
```
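The `doDiff` override above delegates to `AddBpOp`. To build intuition for what any `doDiff` must return, here is a plain-Java sketch that uses no ND4J types; `AddGradientCheck` is a hypothetical helper. It computes the analytic gradients of elementwise addition, one per input, and checks the first against a central finite difference. It assumes equal input shapes. With broadcasting, the gradient of the broadcast input would also need to be summed over the broadcast dimensions.

```java
// Hypothetical helper class for illustration; not part of ND4J.
final class AddGradientCheck {
    private AddGradientCheck() {}

    // Forward pass: elementwise z = x + y (same shapes, no broadcasting).
    static double[] add(double[] x, double[] y) {
        double[] z = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            z[i] = x[i] + y[i];
        }
        return z;
    }

    // Analytic backward pass returning one gradient per input, like doDiff.
    // For z = x + y: dL/dx = dL/dz and dL/dy = dL/dz.
    static double[][] addBackward(double[] gradOut) {
        return new double[][] { gradOut.clone(), gradOut.clone() };
    }

    // Scalar loss L = sum(w_i * z_i), so the upstream gradient dL/dz is w.
    static double loss(double[] z, double[] w) {
        double s = 0;
        for (int i = 0; i < z.length; i++) {
            s += w[i] * z[i];
        }
        return s;
    }

    // Central-difference estimate of dL/dx_i.
    static double numericGradX(double[] x, double[] y, double[] w, int i, double eps) {
        double[] xp = x.clone();
        xp[i] += eps;
        double[] xm = x.clone();
        xm[i] -= eps;
        return (loss(add(xp, y), w) - loss(add(xm, y), w)) / (2 * eps);
    }
}
```

A mismatch between the analytic and numeric values is a quick signal that a custom `doDiff` is wrong before the op is used for training.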
With the above definition, a user only has to pass an instantiated op to doUdf. As long as the op class is annotated with @UserDefinedOp, it is properly integrated with the SameDiff graph.
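The example op has no fields, so its `propertiesForFunction()` returns an empty map. For an op that does carry configuration, the two property methods must round-trip. A loader creates the op through the no-arg constructor and then restores its state from the saved map. The following minimal plain-Java sketch assumes a hypothetical op with `alpha` and `axis` fields. The map keys and the `Number`-based conversion are illustrative choices, not ND4J's serialization format.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical op with two configurable fields, mirroring the
// propertiesForFunction()/setPropertiesForFunction() contract.
class ScaledAddUdfSketch {
    private double alpha = 1.0;
    private long axis = -1;

    ScaledAddUdfSketch() {} // no-arg constructor used on load

    ScaledAddUdfSketch(double alpha, long axis) {
        this.alpha = alpha;
        this.axis = axis;
    }

    // Exports fields as a map, as would happen when the graph is saved.
    Map<String, Object> propertiesForFunction() {
        Map<String, Object> props = new HashMap<>();
        props.put("alpha", alpha);
        props.put("axis", axis);
        return props;
    }

    // Restores fields from the map. Numeric values are read through Number so that
    // a deserializer producing, e.g., Integer instead of Long is still accepted.
    void setPropertiesForFunction(Map<String, Object> properties) {
        if (properties.containsKey("alpha")) {
            alpha = ((Number) properties.get("alpha")).doubleValue();
        }
        if (properties.containsKey("axis")) {
            axis = ((Number) properties.get("axis")).longValue();
        }
    }

    double alpha() { return alpha; }
    long axis() { return axis; }
}
```

Keys absent from the map leave the constructor defaults in place, which keeps older saved graphs loadable if a field is added later.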
During execution, dedicated code paths in the op executioners call the op's exec() or exec(OpContext) method.
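Beyond the existence of these code paths, the ADR does not specify how an executioner chooses between the two methods. The following minimal plain-Java sketch uses hypothetical stand-ins (`SketchOpContext`, `SketchUserOp`, `SketchExecutioner`) rather than ND4J types. It dispatches on whether a context is supplied. The context-based path keeps arrays on the context, so the same op instance can run with different inputs without mutating its own argument lists.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for OpContext: holds inputs and outputs outside the op.
class SketchOpContext {
    final List<double[]> inputs = new ArrayList<>();
    final List<double[]> outputs = new ArrayList<>();
}

// Hypothetical stand-in for a user-defined op with both exec variants.
abstract class SketchUserOp {
    final List<double[]> inputArguments = new ArrayList<>();
    final List<double[]> outputArguments = new ArrayList<>();

    abstract void exec();                    // reads/writes arrays stored on the op
    abstract void exec(SketchOpContext ctx); // reads/writes arrays stored on the context
}

class SketchAdd extends SketchUserOp {
    private static double[] add(double[] a, double[] b) {
        double[] out = new double[a.length];
        for (int i = 0; i < a.length; i++) {
            out[i] = a[i] + b[i];
        }
        return out;
    }

    @Override
    void exec() {
        outputArguments.add(add(inputArguments.get(0), inputArguments.get(1)));
    }

    @Override
    void exec(SketchOpContext ctx) {
        ctx.outputs.add(add(ctx.inputs.get(0), ctx.inputs.get(1)));
    }
}

// Hypothetical dispatch: use the context-based path when a context is supplied,
// otherwise the path that uses the op's own argument lists.
final class SketchExecutioner {
    private SketchExecutioner() {}

    static void exec(SketchUserOp op, SketchOpContext ctx) {
        if (ctx != null) {
            op.exec(ctx);
        } else {
            op.exec();
        }
    }
}
```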