aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BlogEvaluationBenchmark.java
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BlogEvaluationBenchmark.java')
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BlogEvaluationBenchmark.java117
1 files changed, 117 insertions, 0 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BlogEvaluationBenchmark.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BlogEvaluationBenchmark.java
new file mode 100644
index 00000000000..07634166060
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BlogEvaluationBenchmark.java
@@ -0,0 +1,117 @@
+package com.yahoo.searchlib.rankingexpression.integration.ml;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+
+import java.nio.FloatBuffer;
+import java.util.List;
+
+import static com.yahoo.searchlib.rankingexpression.integration.ml.TestableTensorFlowModel.contextFrom;
+
+/**
+ * Microbenchmark of imported ML model evaluation.
+ *
+ * @author lesters
+ */
+public class BlogEvaluationBenchmark {
+
+ static final String modelDir = "src/test/files/integration/tensorflow/blog/saved";
+
+ public static void main(String[] args) throws ParseException {
+ SavedModelBundle tensorFlowModel = SavedModelBundle.load(modelDir, "serve");
+ ImportedModel model = new TensorFlowImporter().importModel("blog", modelDir, tensorFlowModel);
+
+ Context context = contextFrom(model);
+ Tensor u = generateInputTensor();
+ Tensor d = generateInputTensor();
+ context.put("input_u", new TensorValue(u));
+ context.put("input_d", new TensorValue(d));
+
+ // Parse the ranking expression from imported string to force primitive tensor functions.
+ RankingExpression expression = new RankingExpression(model.expressions().get("y").getRoot().toString());
+ benchmarkJava(expression, context, 20, 200);
+
+ System.out.println("*** Optimizing expression ***");
+ ExpressionOptimizer optimizer = new ExpressionOptimizer();
+ OptimizationReport report = optimizer.optimize(expression, (ContextIndex)context);
+ System.out.println(report.toString());
+
+ benchmarkJava(expression, context, 2000, 20000);
+ benchmarkTensorFlow(tensorFlowModel, 2000, 20000);
+ }
+
+ private static void benchmarkJava(RankingExpression expression, Context context, int warmup, int iterations) {
+ System.out.println("*** Java evaluation - warmup ***");
+ evaluate(expression, context, warmup);
+ System.gc();
+ System.out.println("*** Java evaluation - " + iterations + " iterations ***");
+ double startTime = System.nanoTime();
+ evaluate(expression, context, iterations);
+ double endTime = System.nanoTime();
+ System.out.println("Model evaluation time is " + ((endTime-startTime) / (1000*1000)) + " ms");
+ System.out.println("Average model evaluation time is " + ((endTime-startTime) / (1000*1000)) / iterations + " ms");
+ }
+
+ private static double evaluate(RankingExpression expression, Context context, int iterations) {
+ double result = 0;
+ for (int i = 0 ; i < iterations; i++) {
+ result = expression.evaluate(context).asTensor().sum().asDouble();
+ }
+ return result;
+ }
+
+ private static Tensor generateInputTensor() {
+ Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", 128).build());
+ for (int d0 = 0; d0 < 1; d0++)
+ for (int d1 = 0; d1 < 128; d1++)
+ b.cell(d1 * 1.0 / 128, d0, d1);
+ return b.build();
+ }
+
+ private static void benchmarkTensorFlow(SavedModelBundle tensorFlowModel, int warmup, int iterations) {
+ org.tensorflow.Tensor<?> u = generateInputTensorFlow();
+ org.tensorflow.Tensor<?> d = generateInputTensorFlow();
+
+ System.out.println("*** TensorFlow evaluation - warmup ***");
+ evaluateTensorflow(tensorFlowModel, u, d, warmup);
+
+ System.gc();
+ System.out.println("*** TensorFlow evaluation - " + iterations + " iterations ***");
+ double startTime = System.nanoTime();
+ evaluateTensorflow(tensorFlowModel, u, d, iterations);
+ double endTime = System.nanoTime();
+ System.out.println("Model evaluation time is " + ((endTime-startTime) / (1000*1000) + " ms"));
+ System.out.println("Average model evaluation time is " + ((endTime-startTime) / (1000*1000)) / iterations + " ms");
+ }
+
+ private static double evaluateTensorflow(SavedModelBundle tensorFlowModel, org.tensorflow.Tensor<?> u, org.tensorflow.Tensor<?> d, int iterations) {
+ double result = 0;
+ for (int i = 0 ; i < iterations; i++) {
+ Session.Runner runner = tensorFlowModel.session().runner();
+ runner.feed("input_u", u);
+ runner.feed("input_d", d);
+ List<org.tensorflow.Tensor<?>> results = runner.fetch("y").run();
+ result = TensorConverter.toVespaTensor(results.get(0)).sum().asDouble();
+ }
+ return result;
+ }
+
+ private static org.tensorflow.Tensor<?> generateInputTensorFlow() {
+ FloatBuffer fb = FloatBuffer.allocate(1 * 128);
+ for (int i = 0; i < 128; ++i) {
+ fb.put(i, (float)(i * 1.0 / 128));
+ }
+ return org.tensorflow.Tensor.create(new long[]{ 1, 128 }, fb);
+ }
+
+}