summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-06-04 16:02:13 +0200
committerLester Solbakken <lesters@oath.com>2019-06-04 16:02:13 +0200
commitb95686ff1c311ea59044869b687363f43cf54686 (patch)
tree7369e3ce1567e248a4298a752624ea28fac97698 /model-integration
parent91e553f9f9b9a1b5c159ac3ac649e40e103244fb (diff)
Support a few more tensorflow operations
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java1
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java110
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java19
3 files changed, 129 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index 419bc7ddf28..714953fbd45 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -71,6 +71,7 @@ class GraphImporter {
case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+ case "square": return new Map(modelName, nodeName, inputs, ScalarFunctions.square());
case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan());
case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
new file mode 100644
index 00000000000..46b95233d11
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
@@ -0,0 +1,110 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Optional;
+
+public class Sum extends IntermediateOperation {
+
+ private final AttributeMap attributeMap;
+ private List<String> reduceDimensions;
+
+ public Sum(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if ( ! allInputTypesPresent(2)) return null;
+
+ IntermediateOperation reductionIndices = inputs.get(1);
+ if ( ! reductionIndices.getConstantValue().isPresent()) {
+ throw new IllegalArgumentException("Sum in " + name + ": Reduction indices must be a constant.");
+ }
+ Tensor indices = reductionIndices.getConstantValue().get().asTensor();
+ reduceDimensions = new ArrayList<>();
+
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ for (Iterator<Tensor.Cell> cellIterator = indices.cellIterator(); cellIterator.hasNext();) {
+ Tensor.Cell cell = cellIterator.next();
+ int dimensionIndex = cell.getValue().intValue();
+ if (dimensionIndex < 0) {
+ dimensionIndex = inputType.dimensions().size() - dimensionIndex;
+ }
+ reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name());
+ }
+ return reducedType(inputType, shouldKeepDimensions());
+ }
+
+ // optimization: if keepDims and one reduce dimension that has size 1: same as identity.
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if ( ! allInputTypesPresent(2)) return null;
+
+ TensorFunction inputFunction = inputs.get(0).function().get();
+ TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.sum, reduceDimensions);
+ if (shouldKeepDimensions()) {
+ // multiply with a generated tensor created from the reduced dimensions
+ TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType());
+ for (String name : reduceDimensions) {
+ typeBuilder.indexed(name, 1);
+ }
+ TensorType generatedType = typeBuilder.build();
+ ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
+ Generate generatedFunction = new Generate(generatedType,
+ new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
+ output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply());
+ }
+ return output;
+ }
+
+ @Override
+ public void renameDimensions(DimensionRenamer renamer) {
+ super.renameDimensions(renamer);
+ List<String> renamedDimensions = new ArrayList<>(reduceDimensions.size());
+ for (String name : reduceDimensions) {
+ Optional<String> newName = renamer.dimensionNameOf(name);
+ if (!newName.isPresent()) {
+ return; // presumably, already renamed
+ }
+ renamedDimensions.add(newName.get());
+ }
+ reduceDimensions = renamedDimensions;
+ }
+
+ private boolean shouldKeepDimensions() {
+ Optional<Value> keepDims = attributeMap.get("keep_dims");
+ return keepDims.isPresent() && keepDims.get().asBoolean();
+ }
+
+ private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
+ for (TensorType.Dimension dimension: inputType.type().dimensions()) {
+ if ( ! reduceDimensions.contains(dimension.name())) {
+ builder.add(dimension);
+ } else if (keepDimensions) {
+ builder.add(TensorType.Dimension.indexed(dimension.name(), 1L));
+ }
+ }
+ return builder.build();
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
index a07c0fdf4dc..7e305a4a0bd 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
@@ -2,6 +2,7 @@
package ai.vespa.rankingexpression.importer.tensorflow;
+import ai.vespa.rankingexpression.importer.operations.Sum;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
@@ -71,25 +72,40 @@ class GraphImporter {
case "switch": return new Switch(modelName, nodeName, inputs, nodePort);
// math ops
+ case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
+ case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
case "add_n": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
- case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
+ case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin());
+ case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan());
+ case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil());
+ case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos());
case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
+ case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp());
case "realdiv": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
+ case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log());
case "matmul": return new MatMul(modelName, nodeName, inputs);
case "maximum": return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
case "mean": return new Mean(modelName, nodeName, inputs, attributes);
case "reducemean": return new Mean(modelName, nodeName, inputs, attributes);
case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
case "multiply": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
+ case "negate": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg());
+ case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal());
case "rsqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.rsqrt());
case "select": return new Select(modelName, nodeName, inputs);
case "where3": return new Select(modelName, nodeName, inputs);
case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
+ case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin());
case "squareddifference": return new Join(modelName, nodeName, inputs, ScalarFunctions.squareddifference());
case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
case "subtract": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+ case "sum": return new Sum(modelName, nodeName, inputs, attributes);
+ case "square": return new Map(modelName, nodeName, inputs, ScalarFunctions.square());
+ case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
+ case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan());
+ case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
// nn ops
case "biasadd": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
@@ -109,6 +125,7 @@ class GraphImporter {
IntermediateOperation op = new NoOp(modelName, node.getName(), inputs);
op.warning("Operation '" + node.getOp() + "' is currently not implemented");
+ System.out.println(node.getName() + ": operation '" + node.getOp() + "' is currently not implemented");
return op;
}