diff options
Diffstat (limited to 'model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java')
-rw-r--r-- | model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java | 40 |
1 files changed, 0 insertions, 40 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java index c5355ebdf6f..fc9785f8cc0 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/TestableModel.java @@ -2,7 +2,6 @@ package ai.vespa.rankingexpression.importer.onnx; import ai.vespa.rankingexpression.importer.ImportedModel; -import ai.vespa.rankingexpression.importer.tensorflow.TensorConverter; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; @@ -14,16 +13,9 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; -import org.tensorflow.SavedModelBundle; -import org.tensorflow.Session; -import java.nio.DoubleBuffer; -import java.nio.FloatBuffer; -import java.util.List; import java.util.Map; -import static org.junit.Assert.assertEquals; - public class TestableModel { Tensor evaluateVespa(ImportedModel model, String operationName, Map<String, TensorType> inputs) { @@ -39,38 +31,6 @@ public class TestableModel { return expression.evaluate(context).asTensor(); } - Tensor evaluateTF(SavedModelBundle tensorFlowModel, String operationName, Map<String, TensorType> inputs) { - Session.Runner runner = tensorFlowModel.session().runner(); - for (Map.Entry<String, TensorType> entry : inputs.entrySet()) { - try { - runner.feed(entry.getKey(), tensorFlowFloatInputArgument(1, entry.getValue().dimensions().get(1).size().get().intValue())); - } catch (Exception e) { - runner.feed(entry.getKey(), tensorFlowDoubleInputArgument(1, entry.getValue().dimensions().get(1).size().get().intValue())); - } - } - List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run(); - assertEquals(1, results.size()); - return TensorConverter.toVespaTensor(results.get(0)); - } - - private org.tensorflow.Tensor<?> tensorFlowFloatInputArgument(int d0Size, int d1Size) { - FloatBuffer fb1 = FloatBuffer.allocate(d0Size * d1Size); - int i = 0; - for (int d0 = 0; d0 < d0Size; d0++) - for (int d1 = 0; d1 < d1Size; ++d1) - fb1.put(i++, (float)(d1 * 1.0 / d1Size)); - return org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, fb1); - } - - private org.tensorflow.Tensor<?> tensorFlowDoubleInputArgument(int d0Size, int d1Size) { - DoubleBuffer fb1 = DoubleBuffer.allocate(d0Size * d1Size); - int i = 0; - for (int d0 = 0; d0 < d0Size; d0++) - for (int d1 = 0; d1 < d1Size; ++d1) - fb1.put(i++, (float)(d1 * 1.0 / d1Size)); - return org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, fb1); - } - private Tensor vespaInputArgument(int d0Size, int d1Size) { Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build()); for (int d0 = 0; d0 < d0Size; d0++) |