aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java')
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java40
1 files changed, 0 insertions, 40 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java
index c5355ebdf6f..fc9785f8cc0 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java
@@ -2,7 +2,6 @@
package ai.vespa.rankingexpression.importer.onnx;
import ai.vespa.rankingexpression.importer.ImportedModel;
-import ai.vespa.rankingexpression.importer.tensorflow.TensorConverter;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
@@ -14,16 +13,9 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
-import org.tensorflow.SavedModelBundle;
-import org.tensorflow.Session;
-import java.nio.DoubleBuffer;
-import java.nio.FloatBuffer;
-import java.util.List;
import java.util.Map;
-import static org.junit.Assert.assertEquals;
-
public class TestableModel {
Tensor evaluateVespa(ImportedModel model, String operationName, Map<String, TensorType> inputs) {
@@ -39,38 +31,6 @@ public class TestableModel {
return expression.evaluate(context).asTensor();
}
- Tensor evaluateTF(SavedModelBundle tensorFlowModel, String operationName, Map<String, TensorType> inputs) {
- Session.Runner runner = tensorFlowModel.session().runner();
- for (Map.Entry<String, TensorType> entry : inputs.entrySet()) {
- try {
- runner.feed(entry.getKey(), tensorFlowFloatInputArgument(1, entry.getValue().dimensions().get(1).size().get().intValue()));
- } catch (Exception e) {
- runner.feed(entry.getKey(), tensorFlowDoubleInputArgument(1, entry.getValue().dimensions().get(1).size().get().intValue()));
- }
- }
- List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
- assertEquals(1, results.size());
- return TensorConverter.toVespaTensor(results.get(0));
- }
-
- private org.tensorflow.Tensor<?> tensorFlowFloatInputArgument(int d0Size, int d1Size) {
- FloatBuffer fb1 = FloatBuffer.allocate(d0Size * d1Size);
- int i = 0;
- for (int d0 = 0; d0 < d0Size; d0++)
- for (int d1 = 0; d1 < d1Size; ++d1)
- fb1.put(i++, (float)(d1 * 1.0 / d1Size));
- return org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, fb1);
- }
-
- private org.tensorflow.Tensor<?> tensorFlowDoubleInputArgument(int d0Size, int d1Size) {
- DoubleBuffer fb1 = DoubleBuffer.allocate(d0Size * d1Size);
- int i = 0;
- for (int d0 = 0; d0 < d0Size; d0++)
- for (int d1 = 0; d1 < d1Size; ++d1)
- fb1.put(i++, (float)(d1 * 1.0 / d1Size));
- return org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, fb1);
- }
-
private Tensor vespaInputArgument(int d0Size, int d1Size) {
Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build());
for (int d0 = 0; d0 < d0Size; d0++)