aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTor Egge <Tor.Egge@broadpark.no>2016-09-27 10:30:43 +0200
committerGitHub <noreply@github.com>2016-09-27 10:30:43 +0200
commit5824f886e83bbc7330e3ea7716a4751cdff0f9f6 (patch)
tree5bc05f79fef038aaeda67f4b6426dec5b3b28846
parent9bf07559625d9c600f38956bf0316307ebe62dbb (diff)
parenta0a19332961c4c0fab8809cc4a9e4d1f719b2f58 (diff)
Merge pull request #716 from yahoo/geirst/implement-to-spec-function-for-dense-tensor
Implement function to create a eval::TensorSpec from a tensor::DenseT…
-rw-r--r--vespalib/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp79
-rw-r--r--vespalib/src/vespa/vespalib/eval/tensor_spec.cpp11
-rw-r--r--vespalib/src/vespa/vespalib/eval/tensor_spec.h6
-rw-r--r--vespalib/src/vespa/vespalib/tensor/dense/dense_tensor.cpp28
-rw-r--r--vespalib/src/vespa/vespalib/tensor/dense/dense_tensor.h2
-rw-r--r--vespalib/src/vespa/vespalib/tensor/tensor.h4
-rw-r--r--vespalib/src/vespa/vespalib/test/insertion_operators.h18
7 files changed, 130 insertions, 18 deletions
diff --git a/vespalib/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp b/vespalib/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp
index 8478d46e1f4..a61513fbf9d 100644
--- a/vespalib/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp
+++ b/vespalib/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp
@@ -4,11 +4,11 @@
#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/vespalib/tensor/dense/dense_tensor_builder.h>
#include <vespa/vespalib/util/exceptions.h>
-#include <algorithm>
using namespace vespalib::tensor;
using vespalib::IllegalArgumentException;
using Builder = DenseTensorBuilder;
+using vespalib::eval::TensorSpec;
void
assertTensor(const DenseTensor::DimensionsMeta &expDims,
@@ -20,33 +20,72 @@ assertTensor(const DenseTensor::DimensionsMeta &expDims,
EXPECT_EQUAL(expCells, realTensor.cells());
}
+void
+assertTensorSpec(const TensorSpec &expSpec, const Tensor &tensor)
+{
+ TensorSpec actSpec = tensor.toSpec();
+ EXPECT_EQUAL(expSpec.type(), actSpec.type());
+ EXPECT_EQUAL(expSpec.cells(), actSpec.cells());
+}
+
struct Fixture
{
Builder builder;
};
+Tensor::UP
+build1DTensor(Builder &builder)
+{
+ Builder::Dimension dimX = builder.defineDimension("x", 3);
+ builder.addLabel(dimX, 0).addCell(10).
+ addLabel(dimX, 1).addCell(11).
+ addLabel(dimX, 2).addCell(12);
+ return builder.build();
+}
+
TEST_F("require that 1d tensor can be constructed", Fixture)
{
- Builder::Dimension dimX = f.builder.defineDimension("x", 3);
- f.builder.addLabel(dimX, 0).addCell(10).
- addLabel(dimX, 1).addCell(11).
- addLabel(dimX, 2).addCell(12);
- assertTensor({{"x",3}}, {10,11,12},
- *f.builder.build());
+ assertTensor({{"x",3}}, {10,11,12}, *build1DTensor(f.builder));
+}
+
+TEST_F("require that 1d tensor can be converted to tensor spec", Fixture)
+{
+ assertTensorSpec(TensorSpec("tensor(x[3])").
+ add({{"x", 0}}, 10).
+ add({{"x", 1}}, 11).
+ add({{"x", 2}}, 12),
+ *build1DTensor(f.builder));
+}
+
+Tensor::UP
+build2DTensor(Builder &builder)
+{
+ Builder::Dimension dimX = builder.defineDimension("x", 3);
+ Builder::Dimension dimY = builder.defineDimension("y", 2);
+ builder.addLabel(dimX, 0).addLabel(dimY, 0).addCell(10).
+ addLabel(dimX, 0).addLabel(dimY, 1).addCell(11).
+ addLabel(dimX, 1).addLabel(dimY, 0).addCell(12).
+ addLabel(dimX, 1).addLabel(dimY, 1).addCell(13).
+ addLabel(dimX, 2).addLabel(dimY, 0).addCell(14).
+ addLabel(dimX, 2).addLabel(dimY, 1).addCell(15);
+ return builder.build();
}
TEST_F("require that 2d tensor can be constructed", Fixture)
{
- Builder::Dimension dimX = f.builder.defineDimension("x", 3);
- Builder::Dimension dimY = f.builder.defineDimension("y", 2);
- f.builder.addLabel(dimX, 0).addLabel(dimY, 0).addCell(10).
- addLabel(dimX, 0).addLabel(dimY, 1).addCell(11).
- addLabel(dimX, 1).addLabel(dimY, 0).addCell(12).
- addLabel(dimX, 1).addLabel(dimY, 1).addCell(13).
- addLabel(dimX, 2).addLabel(dimY, 0).addCell(14).
- addLabel(dimX, 2).addLabel(dimY, 1).addCell(15);
- assertTensor({{"x",3},{"y",2}}, {10,11,12,13,14,15},
- *f.builder.build());
+ assertTensor({{"x",3},{"y",2}}, {10,11,12,13,14,15}, *build2DTensor(f.builder));
+}
+
+TEST_F("require that 2d tensor can be converted to tensor spec", Fixture)
+{
+ assertTensorSpec(TensorSpec("tensor(x[3],y[2])").
+ add({{"x", 0},{"y", 0}}, 10).
+ add({{"x", 0},{"y", 1}}, 11).
+ add({{"x", 1},{"y", 0}}, 12).
+ add({{"x", 1},{"y", 1}}, 13).
+ add({{"x", 2},{"y", 0}}, 14).
+ add({{"x", 2},{"y", 1}}, 15),
+ *build2DTensor(f.builder));
}
TEST_F("require that 3d tensor can be constructed", Fixture)
@@ -189,7 +228,6 @@ TEST_F("require that already specified label throws exception", Fixture)
"Label for dimension 'x' is already specified with value '0'");
}
-
TEST_F("require that dimensions are sorted", Fixture)
{
Builder::Dimension dimY = f.builder.defineDimension("y", 3);
@@ -205,4 +243,9 @@ TEST_F("require that dimensions are sorted", Fixture)
EXPECT_EQUAL("tensor(x[5],y[3])", denseTensor.getType().to_spec());
}
+
+
+
+
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/vespalib/src/vespa/vespalib/eval/tensor_spec.cpp b/vespalib/src/vespa/vespalib/eval/tensor_spec.cpp
index 28cda1b2962..9e3cb573555 100644
--- a/vespalib/src/vespa/vespalib/eval/tensor_spec.cpp
+++ b/vespalib/src/vespa/vespalib/eval/tensor_spec.cpp
@@ -2,9 +2,20 @@
#include <vespa/fastos/fastos.h>
#include "tensor_spec.h"
+#include <iostream>
namespace vespalib {
namespace eval {
+std::ostream &operator<<(std::ostream &out, const TensorSpec::Label &label)
+{
+ if (label.is_indexed()) {
+ out << label.index;
+ } else {
+ out << label.name;
+ }
+ return out;
+}
+
} // namespace vespalib::eval
} // namespace vespalib
diff --git a/vespalib/src/vespa/vespalib/eval/tensor_spec.h b/vespalib/src/vespa/vespalib/eval/tensor_spec.h
index aff23a42832..8085242c58a 100644
--- a/vespalib/src/vespa/vespalib/eval/tensor_spec.h
+++ b/vespalib/src/vespa/vespalib/eval/tensor_spec.h
@@ -31,6 +31,10 @@ public:
}
return (name < rhs.name);
}
+ bool operator==(const Label &rhs) const {
+ return (index == rhs.index) &&
+ (name == rhs.name);
+ }
};
using Address = std::map<vespalib::string,Label>;
using Cells = std::map<Address,double>;
@@ -47,5 +51,7 @@ public:
const Cells &cells() const { return _cells; }
};
+std::ostream &operator<<(std::ostream &out, const TensorSpec::Label &label);
+
} // namespace vespalib::eval
} // namespace vespalib
diff --git a/vespalib/src/vespa/vespalib/tensor/dense/dense_tensor.cpp b/vespalib/src/vespa/vespalib/tensor/dense/dense_tensor.cpp
index 5a160329e79..73096843f78 100644
--- a/vespalib/src/vespa/vespalib/tensor/dense/dense_tensor.cpp
+++ b/vespalib/src/vespa/vespalib/tensor/dense/dense_tensor.cpp
@@ -11,6 +11,7 @@
#include <vespa/vespalib/tensor/tensor_visitor.h>
#include <sstream>
+using vespalib::eval::TensorSpec;
namespace vespalib {
namespace tensor {
@@ -323,6 +324,33 @@ DenseTensor::clone() const
return std::make_unique<DenseTensor>(_dimensionsMeta, _cells);
}
+namespace {
+
+void
+buildAddress(const DenseTensor::CellsIterator &itr, TensorSpec::Address &address)
+{
+ auto addressItr = itr.address().begin();
+ for (const auto &dim : itr.dimensions()) {
+ address.emplace(std::make_pair(dim.dimension(), TensorSpec::Label(*addressItr++)));
+ }
+ assert(addressItr == itr.address().end());
+}
+
+}
+
+TensorSpec
+DenseTensor::toSpec() const
+{
+ TensorSpec result(getType().to_spec());
+ TensorSpec::Address address;
+ for (CellsIterator itr(_dimensionsMeta, _cells); itr.valid(); itr.next()) {
+ buildAddress(itr, address);
+ result.add(address, itr.cell());
+ address.clear();
+ }
+ return result;
+}
+
void
DenseTensor::print(std::ostream &out) const
{
diff --git a/vespalib/src/vespa/vespalib/tensor/dense/dense_tensor.h b/vespalib/src/vespa/vespalib/tensor/dense/dense_tensor.h
index 73d9c26c408..c8a82bfe73e 100644
--- a/vespalib/src/vespa/vespalib/tensor/dense/dense_tensor.h
+++ b/vespalib/src/vespa/vespalib/tensor/dense/dense_tensor.h
@@ -69,6 +69,7 @@ public:
void next();
double cell() const { return _cells[_cellIdx]; }
const std::vector<size_t> &address() const { return _address; }
+ const DimensionsMeta &dimensions() const { return _dimensionsMeta; }
};
@@ -103,6 +104,7 @@ public:
virtual void print(std::ostream &out) const override;
virtual vespalib::string toString() const override;
virtual Tensor::UP clone() const override;
+ virtual eval::TensorSpec toSpec() const override;
virtual void accept(TensorVisitor &visitor) const override;
};
diff --git a/vespalib/src/vespa/vespalib/tensor/tensor.h b/vespalib/src/vespa/vespalib/tensor/tensor.h
index 4128a27d9a7..e713ba161eb 100644
--- a/vespalib/src/vespa/vespalib/tensor/tensor.h
+++ b/vespalib/src/vespa/vespalib/tensor/tensor.h
@@ -6,6 +6,7 @@
#include "tensor_address.h"
#include <vespa/vespalib/stllike/string.h>
#include <vespa/vespalib/eval/tensor.h>
+#include <vespa/vespalib/eval/tensor_spec.h>
#include <vespa/vespalib/eval/value_type.h>
namespace vespalib {
@@ -41,6 +42,9 @@ struct Tensor : public eval::Tensor
virtual void print(std::ostream &out) const = 0;
virtual vespalib::string toString() const = 0;
virtual Tensor::UP clone() const = 0;
+ virtual eval::TensorSpec toSpec() const {
+ return eval::TensorSpec(getType().to_spec());
+ }
virtual void accept(TensorVisitor &visitor) const = 0;
};
diff --git a/vespalib/src/vespa/vespalib/test/insertion_operators.h b/vespalib/src/vespa/vespalib/test/insertion_operators.h
index 8ed52062281..ac4fa3541e3 100644
--- a/vespalib/src/vespa/vespalib/test/insertion_operators.h
+++ b/vespalib/src/vespa/vespalib/test/insertion_operators.h
@@ -1,6 +1,7 @@
// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#pragma once
+#include <map>
#include <ostream>
#include <set>
#include <vector>
@@ -41,5 +42,22 @@ operator<<(std::ostream &os, const std::vector<T> &set)
return os;
}
+template <typename K, typename V>
+std::ostream &
+operator<<(std::ostream &os, const std::map<K, V> &map)
+{
+ os << "{";
+ bool first = true;
+ for (const auto &entry : map) {
+ if (!first) {
+ os << ",";
+ }
+ os << "{" << entry.first << "," << entry.second << "}";
+ first = false;
+ }
+ os << "}";
+ return os;
+}
+
} // namespace std