diff options
Diffstat (limited to 'config-model/src/test/java/com/yahoo/schema/derived/NeuralNetTestCase.java')
-rw-r--r-- | config-model/src/test/java/com/yahoo/schema/derived/NeuralNetTestCase.java | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/config-model/src/test/java/com/yahoo/schema/derived/NeuralNetTestCase.java b/config-model/src/test/java/com/yahoo/schema/derived/NeuralNetTestCase.java new file mode 100644 index 00000000000..6e584099331 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/schema/derived/NeuralNetTestCase.java @@ -0,0 +1,42 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.schema.derived; + +import com.yahoo.search.Query; +import com.yahoo.search.query.profile.compiled.CompiledQueryProfile; +import com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry; +import com.yahoo.search.query.profile.config.QueryProfileConfigurer; +import com.yahoo.schema.parser.ParseException; +import com.yahoo.vespa.model.container.search.QueryProfiles; +import org.junit.Test; + +import java.io.IOException; + +import com.yahoo.component.ComponentId; + +import static org.junit.Assert.assertEquals; + +public class NeuralNetTestCase extends AbstractExportingTestCase { + + @Test + public void testNeuralNet() throws IOException, ParseException { + ComponentId.resetGlobalCountersForTests(); + DerivedConfiguration c = assertCorrectDeriving("neuralnet"); + // Verify that query profiles end up correct when passed through the same intermediate forms as a full system + CompiledQueryProfileRegistry queryProfiles = + QueryProfileConfigurer.createFromConfig(new QueryProfiles(c.getQueryProfiles(), (level, message) -> {}).getConfig()).compile(); + assertNeuralNetQuery(c, queryProfiles.getComponent("default")); + } + + @Test + public void testNeuralNet_noQueryProfiles() throws IOException, ParseException { + ComponentId.resetGlobalCountersForTests(); + DerivedConfiguration c = assertCorrectDeriving("neuralnet_noqueryprofile"); + } + + private void assertNeuralNetQuery(DerivedConfiguration c, CompiledQueryProfile defaultprofile) { + Query q = new Query("?test=foo&ranking.features.query(b_1)=[1,2,3,4,5,6,7,8,9]", defaultprofile); + assertEquals("tensor(out[9]):[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]", + q.properties().get("ranking.features.query(b_1)").toString()); + } + +} |