summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-12-05 09:18:52 +0100
committerLester Solbakken <lesters@oath.com>2019-12-05 09:18:52 +0100
commit9e22a529219670ce25b49666de95cd062686b210 (patch)
tree1037f0db07611300710460f1cc0ad4f9f250e0de /model-integration
parentcd4e23a47c1993d5c9dbe17dfb23bdce3e037844 (diff)
Fix headers and comments
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java1
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java2
5 files changed, 6 insertions, 9 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index 63b04470d00..55f5d979ea8 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -24,7 +24,6 @@ import ai.vespa.rankingexpression.importer.operations.NoOp;
import ai.vespa.rankingexpression.importer.operations.Reshape;
import ai.vespa.rankingexpression.importer.operations.Shape;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.ScalarFunctions;
import onnx.Onnx;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java
index 497e7e7550d..ea6bb2eaf99 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java
@@ -1,4 +1,4 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
@@ -7,7 +7,6 @@ import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;
-import java.util.Optional;
public class ConcatReduce extends IntermediateOperation {
@@ -22,7 +21,7 @@ public class ConcatReduce extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
if ( ! allInputTypesPresent(inputs.size())) return null;
- return inputs.get(0).type().get(); // todo, not necessarily so. Broadcasting etc?
+ return inputs.get(0).type().get();
}
@Override
@@ -66,7 +65,6 @@ public class ConcatReduce extends IntermediateOperation {
return a.rank() < b.rank() ? a : b;
}
-
@Override
public ConcatReduce withInputs(List<IntermediateOperation> inputs) {
return new ConcatReduce(modelName(), name(), inputs, aggregator);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
index b3fe1da931e..7af051484f5 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
@@ -63,7 +63,7 @@ public class Reduce extends IntermediateOperation {
for (Value i : attributeMap.getList("axes").get()) {
int dimensionIndex = (int) i.asDouble();
if (dimensionIndex < 0) {
- dimensionIndex = inputType.dimensions().size() - (-1 * dimensionIndex);
+ dimensionIndex = inputType.dimensions().size() + dimensionIndex;
}
reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name());
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index 1b72565b423..c88fc18e6c6 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -56,7 +56,7 @@ public class Reshape extends IntermediateOperation {
List<Integer> dimSizes = new ArrayList<>(shape.type().rank());
shape.valueIterator().forEachRemaining(v -> dimSizes.add(v.intValue()));
- // first pass - set 0 values
+ // first pass - set 0 values, meaning that size is retained from input
for (int i = 0; i < dimSizes.size(); ++i) {
if (dimSizes.get(i) == 0) {
if (i >= inputType.dimensions().size()) {
@@ -66,7 +66,7 @@ public class Reshape extends IntermediateOperation {
}
}
- // second pass - set any -1 values
+ // second pass - set any -1 value, meaning that the dimension size should be expanded to fill the tensor
for (int i = 0; i < dimSizes.size(); ++i) {
if (dimSizes.get(i) < 0) {
int shapeSize = dimSizes.stream().reduce(1, (a, b) -> a * b);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
index 306387ad206..83086926316 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
@@ -37,7 +37,7 @@ public class Softmax extends IntermediateOperation {
OrderedTensorType inputType = inputs.get(0).type().get();
- int axis = inputType.rank() == 1 ? 0 : 1; // assumption: first dimension is batch dimension, except if there's only one dimension
+ int axis = inputType.rank() == 1 ? 0 : 1; // assumption: first dimension is batch dimension
if (attributeMap.get("axis").isPresent()) {
axis = (int)attributeMap.get("axis").get().asDouble();
}