diff options
author | Håvard Pettersen <havardpe@oath.com> | 2017-11-06 15:27:58 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2017-11-07 14:47:34 +0000 |
commit | 0970129d98a386753e2fa24c559c77392691c633 (patch) | |
tree | aea271d8b97ff24fb3f4020b09d26901b978ddba /eval/src/apps | |
parent | f5957dbf63a5fcd7df5df9062ef0324a52ed8605 (diff) |
clean up tensor engine API
make Tensor a subclass of Value
Diffstat (limited to 'eval/src/apps')
-rw-r--r-- | eval/src/apps/eval_expr/eval_expr.cpp | 2 | ||||
-rw-r--r-- | eval/src/apps/tensor_conformance/tensor_conformance.cpp | 51 |
2 files changed, 13 insertions, 40 deletions
diff --git a/eval/src/apps/eval_expr/eval_expr.cpp b/eval/src/apps/eval_expr/eval_expr.cpp index 2e1f7f7fdcb..71c808174b8 100644 --- a/eval/src/apps/eval_expr/eval_expr.cpp +++ b/eval/src/apps/eval_expr/eval_expr.cpp @@ -26,7 +26,7 @@ int main(int argc, char **argv) { if (result.is_double()) { fprintf(stdout, "%.32g\n", result.as_double()); } else if (result.is_tensor()) { - vespalib::string str = SimpleTensorEngine::ref().to_spec(*result.as_tensor()).to_string(); + vespalib::string str = SimpleTensorEngine::ref().to_spec(result).to_string(); fprintf(stdout, "%s\n", str.c_str()); } else { fprintf(stdout, "error\n"); diff --git a/eval/src/apps/tensor_conformance/tensor_conformance.cpp b/eval/src/apps/tensor_conformance/tensor_conformance.cpp index d1163fb579d..616b98f0809 100644 --- a/eval/src/apps/tensor_conformance/tensor_conformance.cpp +++ b/eval/src/apps/tensor_conformance/tensor_conformance.cpp @@ -60,69 +60,42 @@ nbostream extract_data(const Inspector &value) { //----------------------------------------------------------------------------- -TensorSpec to_spec(const Value &value) { - if (value.is_error()) { - return TensorSpec("error"); - } else if (value.is_double()) { - return TensorSpec("double").add({}, value.as_double()); - } else { - ASSERT_TRUE(value.is_tensor()); - auto tensor = value.as_tensor(); - return tensor->engine().to_spec(*tensor); - } -} - -const Value &to_value(const TensorSpec &spec, const TensorEngine &engine, Stash &stash) { - if (spec.type() == "error") { - return stash.create<ErrorValue>(); - } else if (spec.type() == "double") { - double value = 0.0; - for (const auto &cell: spec.cells()) { - value += cell.second; - } - return stash.create<DoubleValue>(value); - } else { - ASSERT_TRUE(starts_with(spec.type(), "tensor(")); - return stash.create<TensorValue>(engine.create(spec)); - } -} - void insert_value(Cursor &cursor, const vespalib::string &name, const TensorSpec &spec) { - Stash stash; nbostream data; - const Value &value = to_value(spec, SimpleTensorEngine::ref(), stash); - SimpleTensorEngine::ref().encode(value, data, stash); + Value::UP value = SimpleTensorEngine::ref().from_spec(spec); + SimpleTensorEngine::ref().encode(*value, data); cursor.setData(name, Memory(data.peek(), data.size())); } TensorSpec extract_value(const Inspector &inspector) { - Stash stash; nbostream data = extract_data(inspector); - return to_spec(SimpleTensorEngine::ref().decode(data, stash)); + const auto &engine = SimpleTensorEngine::ref(); + return engine.to_spec(*engine.decode(data)); } //----------------------------------------------------------------------------- -std::vector<ValueType> get_types(const std::vector<Value::CREF> ¶m_values) { +std::vector<ValueType> get_types(const std::vector<Value::UP> ¶m_values) { std::vector<ValueType> param_types; for (size_t i = 0; i < param_values.size(); ++i) { - param_types.emplace_back(param_values[i].get().type()); + param_types.emplace_back(param_values[i]->type()); } return param_types; } TensorSpec eval_expr(const Inspector &test, const TensorEngine &engine, bool typed) { - Stash stash; Function fun = Function::parse(test["expression"].asString().make_string()); - std::vector<Value::CREF> param_values; + std::vector<Value::UP> param_values; + std::vector<Value::CREF> param_refs; for (size_t i = 0; i < fun.num_params(); ++i) { - param_values.emplace_back(to_value(extract_value(test["inputs"][fun.param_name(i)]), engine, stash)); + param_values.emplace_back(engine.from_spec(extract_value(test["inputs"][fun.param_name(i)]))); + param_refs.emplace_back(*param_values.back()); } NodeTypes types = typed ? NodeTypes(fun, get_types(param_values)) : NodeTypes(); InterpretedFunction ifun(engine, fun, types); InterpretedFunction::Context ctx(ifun); - InterpretedFunction::SimpleObjectParams params(param_values); - return to_spec(ifun.eval(ctx, params)); + InterpretedFunction::SimpleObjectParams params(param_refs); + return engine.to_spec(ifun.eval(ctx, params)); } //----------------------------------------------------------------------------- |