diff options
author | Lester Solbakken <lesters@oath.com> | 2019-06-04 16:02:13 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-06-04 16:02:13 +0200 |
commit | b95686ff1c311ea59044869b687363f43cf54686 (patch) | |
tree | 7369e3ce1567e248a4298a752624ea28fac97698 /model-integration | |
parent | 91e553f9f9b9a1b5c159ac3ac649e40e103244fb (diff) |
Support a few more tensorflow operations
Diffstat (limited to 'model-integration')
3 files changed, 129 insertions(+), 1 deletion(-)
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index 419bc7ddf28..714953fbd45 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -71,6 +71,7 @@ class GraphImporter { case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt()); case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); + case "square": return new Map(modelName, nodeName, inputs, ScalarFunctions.square()); case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan()); case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java new file mode 100644 index 00000000000..46b95233d11 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java @@ -0,0 +1,110 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; + +public class Sum extends IntermediateOperation { + + private final AttributeMap attributeMap; + private List<String> reduceDimensions; + + public Sum(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! allInputTypesPresent(2)) return null; + + IntermediateOperation reductionIndices = inputs.get(1); + if ( ! 
reductionIndices.getConstantValue().isPresent()) { + throw new IllegalArgumentException("Sum in " + name + ": Reduction indices must be a constant."); + } + Tensor indices = reductionIndices.getConstantValue().get().asTensor(); + reduceDimensions = new ArrayList<>(); + + OrderedTensorType inputType = inputs.get(0).type().get(); + for (Iterator<Tensor.Cell> cellIterator = indices.cellIterator(); cellIterator.hasNext();) { + Tensor.Cell cell = cellIterator.next(); + int dimensionIndex = cell.getValue().intValue(); + if (dimensionIndex < 0) { + dimensionIndex = inputType.dimensions().size() - dimensionIndex; + } + reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name()); + } + return reducedType(inputType, shouldKeepDimensions()); + } + + // optimization: if keepDims and one reduce dimension that has size 1: same as identity. + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! allInputTypesPresent(2)) return null; + + TensorFunction inputFunction = inputs.get(0).function().get(); + TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.sum, reduceDimensions); + if (shouldKeepDimensions()) { + // multiply with a generated tensor created from the reduced dimensions + TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType()); + for (String name : reduceDimensions) { + typeBuilder.indexed(name, 1); + } + TensorType generatedType = typeBuilder.build(); + ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); + Generate generatedFunction = new Generate(generatedType, + new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); + output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply()); + } + return output; + } + + @Override + public void renameDimensions(DimensionRenamer renamer) { + super.renameDimensions(renamer); + List<String> renamedDimensions = new ArrayList<>(reduceDimensions.size()); + for (String name 
: reduceDimensions) { + Optional<String> newName = renamer.dimensionNameOf(name); + if (!newName.isPresent()) { + return; // presumably, already renamed + } + renamedDimensions.add(newName.get()); + } + reduceDimensions = renamedDimensions; + } + + private boolean shouldKeepDimensions() { + Optional<Value> keepDims = attributeMap.get("keep_dims"); + return keepDims.isPresent() && keepDims.get().asBoolean(); + } + + private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType()); + for (TensorType.Dimension dimension: inputType.type().dimensions()) { + if ( ! reduceDimensions.contains(dimension.name())) { + builder.add(dimension); + } else if (keepDimensions) { + builder.add(TensorType.Dimension.indexed(dimension.name(), 1L)); + } + } + return builder.build(); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java index a07c0fdf4dc..7e305a4a0bd 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.tensorflow; +import ai.vespa.rankingexpression.importer.operations.Sum; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.OrderedTensorType; @@ -71,25 +72,40 @@ class GraphImporter { case "switch": return new Switch(modelName, nodeName, inputs, nodePort); // math ops + case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs()); + case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); case "add": return new 
Join(modelName, nodeName, inputs, ScalarFunctions.add()); case "add_n": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); - case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); + case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin()); + case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan()); + case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil()); + case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos()); case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); + case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp()); case "realdiv": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); + case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log()); case "matmul": return new MatMul(modelName, nodeName, inputs); case "maximum": return new Join(modelName, nodeName, inputs, ScalarFunctions.max()); case "mean": return new Mean(modelName, nodeName, inputs, attributes); case "reducemean": return new Mean(modelName, nodeName, inputs, attributes); case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); case "multiply": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); + case "negate": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg()); + case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal()); case "rsqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.rsqrt()); case "select": return new Select(modelName, nodeName, inputs); case "where3": return new Select(modelName, nodeName, inputs); case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); + case "sin": return new Map(modelName, nodeName, inputs, 
ScalarFunctions.sin()); case "squareddifference": return new Join(modelName, nodeName, inputs, ScalarFunctions.squareddifference()); case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); case "subtract": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); + case "sum": return new Sum(modelName, nodeName, inputs, attributes); + case "square": return new Map(modelName, nodeName, inputs, ScalarFunctions.square()); + case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt()); + case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan()); + case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); // nn ops case "biasadd": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); @@ -109,6 +125,7 @@ class GraphImporter { IntermediateOperation op = new NoOp(modelName, node.getName(), inputs); op.warning("Operation '" + node.getOp() + "' is currently not implemented"); + System.out.println(node.getName() + ": operation '" + node.getOp() + "' is currently not implemented"); return op; } |