Skip to content

Commit

Permalink
Merge pull request #52 from WojciechMigda/regressor-missing-api
Browse files Browse the repository at this point in the history
Add missing regressor api, needed for PyTsetlini
  • Loading branch information
WojciechMigda authored Jul 9, 2020
2 parents e1eb094 + 0bb9a5f commit 68c0a0d
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 0 deletions.
4 changes: 4 additions & 0 deletions lib/include/tsetlini_state_json.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@ namespace Tsetlini


struct ClassifierState;
struct RegressorState;

std::string to_json_string(ClassifierState const & state);
void from_json_string(ClassifierState & state, std::string const & js);

std::string to_json_string(RegressorState const & state);
void from_json_string(RegressorState & state, std::string const & js);


} // namespace Tsetlini
16 changes: 16 additions & 0 deletions lib/src/tsetlini_private.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,20 @@ fit_impl(
unsigned int epochs);


[[nodiscard]]
status_message_t
partial_fit_impl(
RegressorState & state,
std::vector<aligned_vector_char> const & X,
response_vector_type const & y,
unsigned int epochs);


[[nodiscard]]
neither::Either<status_message_t, response_vector_type>
predict_impl(
RegressorState const & state,
std::vector<aligned_vector_char> const & X);


} // namespace Tsetlini
42 changes: 42 additions & 0 deletions lib/src/tsetlini_state_json.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,4 +247,46 @@ void from_json_string(ClassifierState & state, std::string const & jss)
}


std::string to_json_string(RegressorState const & state)
{
json js;

js["ta_state"] = state.ta_state;
js["igen"] = state.igen;
js["fgen"] = state.fgen;
js["params"] = state.m_params;

return js.dump();
}


void from_json_string(RegressorState & state, std::string const & jss)
{
auto js = json::parse(jss);

state.igen = js.at("igen").get<IRNG>();
state.fgen = js.at("fgen").get<FRNG>();
state.ta_state = js.at("ta_state").get<Tsetlini::RegressorState::ta_state_v_type>();
state.m_params = js.at("params").get<params_t>();

// So, we need a hack, since stringified json doesn't distinguish
// between signed and unsigned types for integer values > 0.
// most of our params are signed integers, as wrapped
// by std::variant, except for "random_state".
// Json parser will report them as unsigned.
// Here we will re-cast those integers as signed (contrary to json
// enumeration).
for (auto & [k, v]: state.m_params)
{
if (std::holds_alternative<seed_type>(v) and k != "random_state")
{
state.m_params[k] = static_cast<int>(std::get<seed_type>(v));
}
}
// end-of-hack

reset_state_cache(state);
}


} // namespace Tsetlini

0 comments on commit 68c0a0d

Please sign in to comment.