ADRs/0020 - New Control flow.md
Discussion
Proposed by: Adam Gibson (10th April 2022)
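SameDiff supports control flow such as if statements and while loops. However, this is not enough: common looping structures are still hard to express. ONNX has introduced a Loop operation, which takes a preconfigured graph as its loop body. On each iteration the body graph takes as inputs the iteration number, the loop condition and any loop-carried values, and it outputs the updated condition, the updated loop-carried values and, optionally, scan outputs.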
It approximates a for loop with the following code:
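```java
boolean cond = ...;       // initial loop condition
int maxIterations = ...;  // maximum number of iterations (trip count)
for (int i = 0; i < maxIterations && cond; i++) {
    // loop body: may update cond and the loop-carried values
}
```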
In ONNX, the loop body is represented as a subgraph attribute (`body`) on the operation.
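To make these semantics concrete, here is a plain-Java sketch of how an ONNX-style Loop drives its body. It is not SameDiff or ONNX Runtime code: the body is modeled as a hypothetical `Body` interface instead of a subgraph, loop-carried values are simplified to a `double[]`, and scan outputs are omitted.

```java
/**
 * Hypothetical sketch of ONNX Loop semantics in plain Java (not SameDiff or ONNX Runtime code).
 * The body receives (iteration number, condition, loop-carried values) and
 * returns (new condition, new loop-carried values).
 */
final class OnnxLoopSketch {

    /** Result of one body invocation: the updated condition and loop-carried values. */
    static final class BodyResult {
        final boolean cond;
        final double[] carried;

        BodyResult(boolean cond, double[] carried) {
            this.cond = cond;
            this.carried = carried;
        }
    }

    /** The loop body, modeled as a function instead of an ONNX subgraph. */
    interface Body {
        BodyResult apply(long iteration, boolean cond, double[] carried);
    }

    /** Runs the body while i < maxTripCount and cond is true, feeding outputs back as inputs. */
    static double[] run(long maxTripCount, boolean cond, double[] initial, Body body) {
        double[] carried = initial.clone();
        for (long i = 0; i < maxTripCount && cond; i++) {
            BodyResult r = body.apply(i, cond, carried);
            cond = r.cond;          // the body decides whether to continue
            carried = r.carried;    // loop-carried values flow into the next iteration
        }
        return carried;
    }
}
```

For example, a body that increments a counter and returns `counter < 3` as the new condition stops after three iterations even when the trip count is 10, while a body that always returns `true` is stopped by the trip count alone.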
Modeled on ONNX's Loop operation, we provide a new loop that combines `Invoke` with a few fixed graph conventions to execute a loop body:
```java
/**
 * Loop with conditions.
 * For more information see the underlying method
 * {@link ControlFlow#loopWithConditions(String[], String, SameDiff, SameDiff, String, SDVariable[], String[], String[])}
 * @param loopParams the loop parameters to loop with
 * @return the output variables of the loop
 */
public SDVariable[] loopWithConditions(ControlFlow.LoopParams loopParams) {
    // implementation omitted
}
```
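The fixed conventions can be modeled in plain Java as follows. This is a hypothetical sketch, not the nd4j implementation: the class `LoopConventionSketch`, the use of a `Map<String, Double>` for graph values and the encoding of the condition as a non-zero number are all assumptions. It also assumes that the first three loop variables are the current iteration, the maximum number of iterations and the condition (the order used in the example in this ADR), that the body's k-th output feeds its k-th input on the next iteration, and that the body itself increments the iteration counter.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.UnaryOperator;

/**
 * Hypothetical plain-Java model of the loopWithConditions conventions (not nd4j code).
 * Graph values are modeled as named doubles; the condition is true while non-zero.
 */
final class LoopConventionSketch {

    /**
     * Runs the body until the iteration limit is reached or the condition becomes false.
     * Assumes bodyInputs, bodyOutputs and loopVars have the same length, with positions
     * 0, 1 and 2 holding the current iteration, the iteration limit and the condition.
     */
    static Map<String, Double> run(String[] bodyInputs, String[] bodyOutputs,
                                   double[] loopVars, UnaryOperator<Map<String, Double>> body) {
        // Bind the initial loop variables to the body's input names, position by position.
        Map<String, Double> state = new HashMap<>();
        for (int k = 0; k < bodyInputs.length; k++) {
            state.put(bodyInputs[k], loopVars[k]);
        }
        while (state.get(bodyInputs[0]) < state.get(bodyInputs[1])
                && state.get(bodyInputs[2]) != 0.0) {
            Map<String, Double> out = body.apply(state);
            // The k-th body output becomes the k-th body input of the next iteration.
            Map<String, Double> next = new HashMap<>();
            for (int k = 0; k < bodyOutputs.length; k++) {
                next.put(bodyInputs[k], out.get(bodyOutputs[k]));
            }
            state = next;
        }
        return state;
    }
}
```

With a body that doubles `input` and sets `cond_in` to 0 once the result reaches 8, starting from `input = 1` and `max_iterations = 10`, the loop stops after three iterations with `input = 8`: the condition, not the iteration limit, ends it.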
LoopParams looks like the following:
```java
public static class LoopParams {
    private String[] outputVarNames;
    private String loopName;
    private SameDiff parent;
    private SameDiff functionBody;
    private String functionName;
    private SDVariable[] loopVars;
    private String[] functionBodyInputs;
    private String[] functionBodyOutputs;
}
```
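Because `loopVars`, `functionBodyInputs` and `functionBodyOutputs` in the class above are matched by position (as the example in this ADR shows), a mismatch in their lengths is an easy mistake to make. A minimal sketch of a consistency check follows; `LoopParamsCheck` is hypothetical and not part of nd4j, and the minimum of three entries reflects the current-iteration, max-iterations and condition variables used in the example.

```java
/**
 * Hypothetical consistency check for the positional LoopParams conventions
 * (not part of the nd4j API): the first three loop variables are the
 * iteration counter, the iteration limit and the condition, and each
 * loop variable lines up with one body input and one body output.
 */
final class LoopParamsCheck {
    static final int RESERVED = 3; // curr_iteration, max_iterations, cond_in

    static void validate(String[] bodyInputs, String[] bodyOutputs, int loopVarCount) {
        if (bodyInputs.length < RESERVED) {
            throw new IllegalArgumentException(
                    "body needs at least " + RESERVED + " inputs, got " + bodyInputs.length);
        }
        if (bodyInputs.length != bodyOutputs.length) {
            throw new IllegalArgumentException("inputs (" + bodyInputs.length + ") and outputs ("
                    + bodyOutputs.length + ") must have the same length");
        }
        if (bodyInputs.length != loopVarCount) {
            throw new IllegalArgumentException(
                    "expected " + bodyInputs.length + " loop variables, got " + loopVarCount);
        }
    }
}
```

Checking this before building the loop would turn a positional mix-up into an immediate, descriptive error rather than a failure inside graph execution.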
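`LoopParams` has the following fields:

- `outputVarNames`: the names of the loop's output variables.
- `loopName`: the name of the loop operation.
- `parent`: the graph the loop is added to.
- `functionBody`: the graph executed on each iteration (the loop body).
- `functionName`: the name of the function that holds the loop body.
- `loopVars`: the initial loop variables from the parent graph: the current iteration, the maximum number of iterations and the condition, followed by any other inputs.
- `functionBodyInputs`: the input names of the loop body, in the same order as `loopVars`.
- `functionBodyOutputs`: the output names of the loop body, in the same order.

The following example builds a loop whose body adds 1.0 to its input, runs it for 5 iterations and checks the result: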
```java
import org.junit.jupiter.api.Test;
import org.nd4j.autodiff.samediff.ControlFlow;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Collections;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;

public class LoopWithConditionsExample {

    @Test
    public void testLoopWithConditions() {
        //set up the parent graph that passes inputs to the loop body
        SameDiff parent = SameDiff.create();
        SDVariable input = parent.placeHolder("input", DataType.FLOAT);
        //set up the loop body
        SameDiff loopBody = SameDiff.create();
        SDVariable loopInput = loopBody.placeHolder("input", DataType.FLOAT);
        SDVariable output = loopBody.math().add("output", loopInput, 1.0);
        //initialize the default loop variables: the current iteration, the max number of iterations (5) and the loop condition (true)
        SDVariable[] args = ControlFlow.initializeLoopBody(new String[]{"curr_iteration", "max_iterations", "cond_in"}, parent, 5, true);
        SDVariable[] childArgs = ControlFlow.initializeLoopBody(new String[]{"curr_iteration", "max_iterations", "cond_in"}, loopBody, 5, true);
        //input names of the loop body; the 4th input is the value passed in from the parent
        String[] inputNames = {
                "curr_iteration",
                "max_iterations",
                "cond_in",
                "input"
        };
        //output names of the loop body; the 4th output is the result of applying the body to the input
        String[] outputNames = {
                "curr_iteration",
                "max_iterations",
                "cond_in",
                "output"
        };
        //loop variables: the three default arguments followed by the parent's input
        SDVariable[] finalArgs = new SDVariable[args.length + 1];
        System.arraycopy(args, 0, finalArgs, 0, args.length);
        finalArgs[args.length] = input;
        //put it all together in the loop parameters
        ControlFlow.LoopParams loopParams = ControlFlow.LoopParams.builder()
                .parent(parent)
                .functionBody(loopBody)
                .functionBodyInputs(inputNames)
                .functionBodyOutputs(outputNames)
                .loopVars(finalArgs)
                .loopName("loop")
                .functionName("func")
                .build();
        //control the output variable names
        String[] finalOutputNames = new String[outputNames.length];
        for (int i = 0; i < finalOutputNames.length; i++) {
            finalOutputNames[i] = outputNames[i] + "_final";
        }
        //run the loop; the output variables are named after finalOutputNames
        SDVariable[] loopWithConditions = parent.loopWithConditions(finalOutputNames, loopParams);
        //each of the 5 iterations adds 1.0, so ones(5) becomes ones(5) + 5
        INDArray assertion = Nd4j.ones(5).addi(5);
        Map<String, INDArray> output2 = parent.output(Collections.singletonMap("input", Nd4j.ones(5)), "output_final");
        assertEquals(assertion, output2.get("output_final"));
    }
}
```
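The expected value in the example (`Nd4j.ones(5).addi(5)`, every element equal to 6) follows directly from the loop settings. The plain-Java trace below reproduces that arithmetic without nd4j. `AddOneLoopTrace` is a hypothetical helper, and it assumes, as the example's expected value implies, that the body's `output` is fed back as `input` on every iteration.

```java
/**
 * Hypothetical plain-Java trace of the example above (no nd4j).
 * Each iteration applies the body "output = input + 1.0" element-wise,
 * and the result is fed back as the next iteration's input.
 */
final class AddOneLoopTrace {
    static double[] run(double[] input, int maxIterations, boolean cond) {
        double[] current = input.clone(); // leave the caller's array untouched
        for (int i = 0; i < maxIterations && cond; i++) {
            for (int j = 0; j < current.length; j++) {
                current[j] += 1.0;
            }
            // This body never sets the condition to false, so only maxIterations stops the loop.
        }
        return current;
    }
}
```

With `maxIterations = 5` and `cond = true`, every element of a five-element array of ones ends at 1 + 5 = 6, which matches `assertion` in the example.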