diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2019-01-14 15:50:39 +0100 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2019-01-14 15:50:39 +0100 |
commit | 2610ef8b06d6123d869a620eaf0b31f4e7c2860e (patch) | |
tree | 1d56866dbb811179c5fe6f6293cd943c10d09e95 /searchlib | |
parent | 3d0321eca4f93717e4afda679ca735b0b3535ce2 (diff) |
Handle bool nodes in grouping too.
Diffstat (limited to 'searchlib')
6 files changed, 189 insertions, 11 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/expression/BoolResultNode.java b/searchlib/src/main/java/com/yahoo/searchlib/expression/BoolResultNode.java new file mode 100644 index 00000000000..4fc6dab142d --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/expression/BoolResultNode.java @@ -0,0 +1,92 @@ +package com.yahoo.searchlib.expression; + +import com.yahoo.vespa.objects.Deserializer; +import com.yahoo.vespa.objects.ObjectVisitor; +import com.yahoo.vespa.objects.Serializer; + +import java.nio.ByteBuffer; + +public class BoolResultNode extends ResultNode { + public static final int classId = registerClass(0x4000 + 146, BoolResultNode.class); + private boolean value = false; + + public BoolResultNode() { + } + + public BoolResultNode(boolean value) { + this.value = value; + } + /** + * Sets the value of this result. + * + * @param value The value to set. + * @return This, to allow chaining. + */ + public BoolResultNode setValue(boolean value) { + this.value = value; + return this; + } + + @Override + protected int onGetClassId() { + return classId; + } + + @Override + protected void onSerialize(Serializer buf) { + byte v = (byte)(value ? 1 : 0); + buf.putByte(null, v ); + } + + @Override + protected void onDeserialize(Deserializer buf) { + value = buf.getByte(null) != 0; + } + + @Override + public long getInteger() { + return value ? 1 : 0; + } + + @Override + public double getFloat() { + return value ? 1.0 : 0.0; + } + + @Override + public String getString() { + return String.valueOf(value); + } + + @Override + public byte[] getRaw() { + return ByteBuffer.allocate(8).putLong(getInteger()).array(); + } + + @Override + public void negate() { + value = ! value; + } + + + @Override + protected int onCmp(ResultNode rhs) { + return Long.compare(getInteger(), rhs.getInteger()); + } + + @Override + public int hashCode() { + return super.hashCode() + (int)getInteger(); + } + + @Override + public void visitMembers(ObjectVisitor visitor) { + super.visitMembers(visitor); + visitor.visit("value", value); + } + + @Override + public void set(ResultNode rhs) { + value = rhs.getInteger() > 0; + } +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/expression/BoolResultNodeVector.java b/searchlib/src/main/java/com/yahoo/searchlib/expression/BoolResultNodeVector.java new file mode 100644 index 00000000000..a528669500d --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/expression/BoolResultNodeVector.java @@ -0,0 +1,68 @@ +package com.yahoo.searchlib.expression; + +import com.yahoo.vespa.objects.Deserializer; +import com.yahoo.vespa.objects.Serializer; + +import java.util.ArrayList; + +public class BoolResultNodeVector extends ResultNodeVector { + public static final int classId = registerClass(0x4000 + 147, BoolResultNodeVector.class); + private ArrayList<BoolResultNode> vector = new ArrayList<>(); + + public BoolResultNodeVector() {} + public BoolResultNodeVector add(BoolResultNode v) { + vector.add(v); + return this; + } + + public ArrayList<BoolResultNode> getVector() { + return vector; + } + @Override + public ResultNodeVector add(ResultNode r) { + return add((BoolResultNode)r); + } + + @Override + protected int onGetClassId() { + return classId; + } + + @Override + protected void onSerialize(Serializer buf) { + super.onSerialize(buf); + buf.putInt(null, vector.size()); + for (BoolResultNode node : vector) { + node.serialize(buf); + } + } + + @Override + protected void onDeserialize(Deserializer buf) { + super.onDeserialize(buf); + int sz = buf.getInt(null); + vector = new ArrayList<>(); + for (int i = 0; i < sz; i++) { + BoolResultNode node = new BoolResultNode(); + node.deserialize(buf); + vector.add(node); + } + } + + @Override + protected int onCmp(ResultNode rhs) { + if (classId != rhs.getClassId()) { + return (classId - rhs.getClassId()); + } + BoolResultNodeVector b = (BoolResultNodeVector)rhs; + int minLength = vector.size(); + if (b.vector.size() < minLength) { + minLength = b.vector.size(); + } + int diff = 0; + for (int i = 0; (diff == 0) && (i < minLength); i++) { + diff = vector.get(i).compareTo(b.vector.get(i)); + } + return (diff == 0) ? (vector.size() - b.vector.size()) : diff; + } +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/expression/Int8ResultNode.java b/searchlib/src/main/java/com/yahoo/searchlib/expression/Int8ResultNode.java index 88920323703..f240a2d5ef7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/expression/Int8ResultNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/expression/Int8ResultNode.java @@ -18,9 +18,7 @@ public class Int8ResultNode extends NumericResultNode { public static final int classId = registerClass(0x4000 + 104, Int8ResultNode.class); private byte value = 0; - @SuppressWarnings("UnusedDeclaration") public Int8ResultNode() { - // used by deserializer } /** diff --git a/searchlib/src/main/java/com/yahoo/searchlib/expression/Int8ResultNodeVector.java b/searchlib/src/main/java/com/yahoo/searchlib/expression/Int8ResultNodeVector.java index 33734c15ff1..edae250defe 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/expression/Int8ResultNodeVector.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/expression/Int8ResultNodeVector.java @@ -15,7 +15,7 @@ import java.util.ArrayList; public class Int8ResultNodeVector extends ResultNodeVector { public static final int classId = registerClass(0x4000 + 116, Int8ResultNodeVector.class); - private ArrayList<Int8ResultNode> vector = new ArrayList<Int8ResultNode>(); + private ArrayList<Int8ResultNode> vector = new ArrayList<>(); public Int8ResultNodeVector() { @@ -53,7 +53,7 @@ public class Int8ResultNodeVector extends ResultNodeVector { protected void onDeserialize(Deserializer buf) { super.onDeserialize(buf); int sz = buf.getInt(null); - vector = new ArrayList<Int8ResultNode>(); + vector = new ArrayList<>(); for (int i = 0; i < sz; i++) { Int8ResultNode node = new Int8ResultNode((byte)0); node.deserialize(buf); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/expression/IntegerResultNodeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/expression/IntegerResultNodeTestCase.java index 4e079ba2adb..bb2cd640c8f 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/expression/IntegerResultNodeTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/expression/IntegerResultNodeTestCase.java @@ -33,7 +33,7 @@ public class IntegerResultNodeTestCase extends ResultNodeTest { assertThat(new Int16ResultNode().getClassId(), is(Int16ResultNode.classId)); assertThat(new Int32ResultNode().getClassId(), is(Int32ResultNode.classId)); assertThat(new IntegerResultNode().getClassId(), is(IntegerResultNode.classId)); - + assertThat(new BoolResultNode().getClassId(), is(BoolResultNode.classId)); } @Test @@ -80,6 +80,18 @@ public class IntegerResultNodeTestCase extends ResultNodeTest { } @Test + public void testBool() { + BoolResultNode node = new BoolResultNode(); + assertEquals(0, node.getInteger()); + assertEquals(0.0, node.getFloat(), 0.000000000001); + assertEquals("false", node.getString()); + node.setValue(true); + assertEquals(1, node.getInteger()); + assertEquals(1.0, node.getFloat(), 0.000000000001); + assertEquals("true", node.getString()); + } + + @Test public void testInt8() { Int8ResultNode node = new Int8ResultNode(); node.setValue((byte) 5); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/expression/ResultNodeVectorTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/expression/ResultNodeVectorTestCase.java index 2f43ad6843b..2fc1771ece0 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/expression/ResultNodeVectorTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/expression/ResultNodeVectorTestCase.java @@ -23,10 +23,17 @@ public class ResultNodeVectorTestCase extends ResultNodeTest { assertThat(new Int16ResultNodeVector().getClassId(), is(Int16ResultNodeVector.classId)); assertThat(new Int8ResultNodeVector().getClassId(), is(Int8ResultNodeVector.classId)); assertThat(new FloatResultNodeVector().getClassId(), is(FloatResultNodeVector.classId)); + assertThat(new BoolResultNodeVector().getClassId(), is(BoolResultNodeVector.classId)); } @Test public void testVectorAdd() { + BoolResultNodeVector b = new BoolResultNodeVector(); + b.add(new BoolResultNode(true)); + b.add(new BoolResultNode(false)); + b.add((ResultNode)new BoolResultNode(false)); + assertThat(b.getVector().size(), is(3)); + Int8ResultNodeVector i8 = new Int8ResultNodeVector(); i8.add(new Int8ResultNode((byte)9)); i8.add(new Int8ResultNode((byte)2)); @@ -157,11 +164,12 @@ public class ResultNodeVectorTestCase extends ResultNodeTest { @Test public void testSerialize() throws InstantiationException, IllegalAccessException { - assertCorrectSerialization(new FloatResultNodeVector().add(new FloatResultNode(1.1)).add(new FloatResultNode(3.3)), new FloatResultNodeVector()); - assertCorrectSerialization(new IntegerResultNodeVector().add(new IntegerResultNode(1)).add(new IntegerResultNode(3)), new IntegerResultNodeVector()); - assertCorrectSerialization(new Int16ResultNodeVector().add(new Int16ResultNode((short) 1)).add(new Int16ResultNode((short) 3)), new Int16ResultNodeVector()); - assertCorrectSerialization(new Int8ResultNodeVector().add(new Int8ResultNode((byte) 1)).add(new Int8ResultNode((byte) 3)), new Int8ResultNodeVector()); - assertCorrectSerialization(new StringResultNodeVector().add(new StringResultNode("foo")).add(new StringResultNode("bar")), new StringResultNodeVector()); - assertCorrectSerialization(new RawResultNodeVector().add(new RawResultNode(new byte[]{6, 9})).add(new RawResultNode(new byte[]{9, 6})), new RawResultNodeVector()); + assertCorrectSerialization(new FloatResultNodeVector().add(new FloatResultNode(1.1)).add(new FloatResultNode(3.3)), new FloatResultNodeVector()); + assertCorrectSerialization(new IntegerResultNodeVector().add(new IntegerResultNode(1)).add(new IntegerResultNode(3)), new IntegerResultNodeVector()); + assertCorrectSerialization(new Int16ResultNodeVector().add(new Int16ResultNode((short) 1)).add(new Int16ResultNode((short) 3)), new Int16ResultNodeVector()); + assertCorrectSerialization(new Int8ResultNodeVector().add(new Int8ResultNode((byte) 1)).add(new Int8ResultNode((byte) 3)), new Int8ResultNodeVector()); + assertCorrectSerialization(new StringResultNodeVector().add(new StringResultNode("foo")).add(new StringResultNode("bar")), new StringResultNodeVector()); + assertCorrectSerialization(new RawResultNodeVector().add(new RawResultNode(new byte[]{6, 9})).add(new RawResultNode(new byte[]{9, 6})), new RawResultNodeVector()); + assertCorrectSerialization(new BoolResultNodeVector().add(new BoolResultNode(true)).add(new BoolResultNode(false)), new BoolResultNodeVector()); } } |