summaryrefslogtreecommitdiffstats
path: root/eval/src/apps
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2017-11-06 15:27:58 +0000
committerHåvard Pettersen <havardpe@oath.com>2017-11-07 14:47:34 +0000
commit0970129d98a386753e2fa24c559c77392691c633 (patch)
treeaea271d8b97ff24fb3f4020b09d26901b978ddba /eval/src/apps
parentf5957dbf63a5fcd7df5df9062ef0324a52ed8605 (diff)
clean up tensor engine API
make Tensor a subclass of Value
Diffstat (limited to 'eval/src/apps')
-rw-r--r--eval/src/apps/eval_expr/eval_expr.cpp2
-rw-r--r--eval/src/apps/tensor_conformance/tensor_conformance.cpp51
2 files changed, 13 insertions, 40 deletions
diff --git a/eval/src/apps/eval_expr/eval_expr.cpp b/eval/src/apps/eval_expr/eval_expr.cpp
index 2e1f7f7fdcb..71c808174b8 100644
--- a/eval/src/apps/eval_expr/eval_expr.cpp
+++ b/eval/src/apps/eval_expr/eval_expr.cpp
@@ -26,7 +26,7 @@ int main(int argc, char **argv) {
if (result.is_double()) {
fprintf(stdout, "%.32g\n", result.as_double());
} else if (result.is_tensor()) {
- vespalib::string str = SimpleTensorEngine::ref().to_spec(*result.as_tensor()).to_string();
+ vespalib::string str = SimpleTensorEngine::ref().to_spec(result).to_string();
fprintf(stdout, "%s\n", str.c_str());
} else {
fprintf(stdout, "error\n");
diff --git a/eval/src/apps/tensor_conformance/tensor_conformance.cpp b/eval/src/apps/tensor_conformance/tensor_conformance.cpp
index d1163fb579d..616b98f0809 100644
--- a/eval/src/apps/tensor_conformance/tensor_conformance.cpp
+++ b/eval/src/apps/tensor_conformance/tensor_conformance.cpp
@@ -60,69 +60,42 @@ nbostream extract_data(const Inspector &value) {
//-----------------------------------------------------------------------------
-TensorSpec to_spec(const Value &value) {
- if (value.is_error()) {
- return TensorSpec("error");
- } else if (value.is_double()) {
- return TensorSpec("double").add({}, value.as_double());
- } else {
- ASSERT_TRUE(value.is_tensor());
- auto tensor = value.as_tensor();
- return tensor->engine().to_spec(*tensor);
- }
-}
-
-const Value &to_value(const TensorSpec &spec, const TensorEngine &engine, Stash &stash) {
- if (spec.type() == "error") {
- return stash.create<ErrorValue>();
- } else if (spec.type() == "double") {
- double value = 0.0;
- for (const auto &cell: spec.cells()) {
- value += cell.second;
- }
- return stash.create<DoubleValue>(value);
- } else {
- ASSERT_TRUE(starts_with(spec.type(), "tensor("));
- return stash.create<TensorValue>(engine.create(spec));
- }
-}
-
void insert_value(Cursor &cursor, const vespalib::string &name, const TensorSpec &spec) {
- Stash stash;
nbostream data;
- const Value &value = to_value(spec, SimpleTensorEngine::ref(), stash);
- SimpleTensorEngine::ref().encode(value, data, stash);
+ Value::UP value = SimpleTensorEngine::ref().from_spec(spec);
+ SimpleTensorEngine::ref().encode(*value, data);
cursor.setData(name, Memory(data.peek(), data.size()));
}
TensorSpec extract_value(const Inspector &inspector) {
- Stash stash;
nbostream data = extract_data(inspector);
- return to_spec(SimpleTensorEngine::ref().decode(data, stash));
+ const auto &engine = SimpleTensorEngine::ref();
+ return engine.to_spec(*engine.decode(data));
}
//-----------------------------------------------------------------------------
-std::vector<ValueType> get_types(const std::vector<Value::CREF> &param_values) {
+std::vector<ValueType> get_types(const std::vector<Value::UP> &param_values) {
std::vector<ValueType> param_types;
for (size_t i = 0; i < param_values.size(); ++i) {
- param_types.emplace_back(param_values[i].get().type());
+ param_types.emplace_back(param_values[i]->type());
}
return param_types;
}
TensorSpec eval_expr(const Inspector &test, const TensorEngine &engine, bool typed) {
- Stash stash;
Function fun = Function::parse(test["expression"].asString().make_string());
- std::vector<Value::CREF> param_values;
+ std::vector<Value::UP> param_values;
+ std::vector<Value::CREF> param_refs;
for (size_t i = 0; i < fun.num_params(); ++i) {
- param_values.emplace_back(to_value(extract_value(test["inputs"][fun.param_name(i)]), engine, stash));
+ param_values.emplace_back(engine.from_spec(extract_value(test["inputs"][fun.param_name(i)])));
+ param_refs.emplace_back(*param_values.back());
}
NodeTypes types = typed ? NodeTypes(fun, get_types(param_values)) : NodeTypes();
InterpretedFunction ifun(engine, fun, types);
InterpretedFunction::Context ctx(ifun);
- InterpretedFunction::SimpleObjectParams params(param_values);
- return to_spec(ifun.eval(ctx, params));
+ InterpretedFunction::SimpleObjectParams params(param_refs);
+ return engine.to_spec(ifun.eval(ctx, params));
}
//-----------------------------------------------------------------------------