Skip to content

Commit

Permalink
Added SimpleInput to pipeline inputs to hold generated data
Browse files Browse the repository at this point in the history
  • Loading branch information
lenhoanglnh committed Jan 17, 2024
1 parent f096f94 commit ba95325
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 205 deletions.
117 changes: 0 additions & 117 deletions solidago/src/solidago/dataset/__init__.py

This file was deleted.

84 changes: 0 additions & 84 deletions solidago/src/solidago/dataset/tournesol.py

This file was deleted.

6 changes: 3 additions & 3 deletions solidago/src/solidago/generative_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import pandas as pd

from solidago.dataset import SimpleDataset
from solidago.pipeline.inputs import SimpleInput

from .user_model import UserModel, SvdUserModel
from .vouch_model import VouchModel, ErdosRenyiVouchModel
Expand Down Expand Up @@ -48,7 +48,7 @@ def __init__(
def __call__(
self, n_users: int, n_entities: int,
random_seed: Optional[int] = None
) -> SimpleDataset:
) -> SimpleInput:
""" Generates a random dataset
Inputs:
- n_users
Expand All @@ -70,5 +70,5 @@ def __call__(
scores, comparisons = self.engagement_model(users, true_scores)
logger.info(f"Generate comparisons using {self.comparison_model}")
comparisons = self.comparison_model(true_scores, comparisons)
return SimpleDataset(users, vouches, entities, true_scores, scores, comparisons)
return SimpleInput(users, vouches, entities, true_scores, scores, comparisons)

65 changes: 64 additions & 1 deletion solidago/src/solidago/pipeline/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def get_individual_scores(
) -> Optional[pd.DataFrame]:
raise NotImplementedError


class TournesolInputFromPublicDataset(TournesolInput):
def __init__(self, dataset_zip: Union[str, BinaryIO]):
if isinstance(dataset_zip, str) and (
Expand Down Expand Up @@ -137,3 +136,67 @@ def get_user_index(self, public_username: str) -> Optional[int]:
if len(rows) == 0:
return None
return rows.index[0]

class SimpleInput(TournesolInput):
def __init__(
self,
users: pd.DataFrame = None,
vouches: pd.DataFrame = None,
entities: pd.DataFrame = None,
true_scores: pd.DataFrame = None,
user_scores: pd.DataFrame = None,
comparisons: pd.DataFrame = None
):
def df(x, **kwargs):
if x is not None:
return x
dtypes = [(key, kwargs[key]) for key in kwargs]
return pd.DataFrame(np.empty(0, np.dtype(list(dtypes))))

self.users = df(users, user_id=int, public_username=str, trust_score= float)
self.users.index.name = "user_id"
self.vouches = df(vouches, voucher=int, vouchee=int, vouch=float)
self.entities = entities
self.true_scores = true_scores
self.user_scores = df(user_scores, user_id=int, entity_id=int, is_public=bool)
self.comparisons = df(comparisons,
user_id=int, score=float, week_date=str, entity_a=int, entity_b=int)

def get_comparisons(
self,
criteria: Optional[str] = None,
user_id: Optional[int] = None,
) -> pd.DataFrame:
dtf = self.comparisons.copy(deep=False)
if criteria is not None:
dtf = dtf[dtf.criteria == criteria]
if user_id is not None:
dtf = dtf[dtf.user_id == user_id]
dtf["weight"] = 1
return dtf[["user_id", "entity_a", "entity_b", "criteria", "score", "weight"]]

@cached_property
def ratings_properties(self):
user_entities_pairs = pd.Series(
iter(
set(self.comparisons.groupby(["user_id", "entity_a"]).indices.keys())
| set(self.comparisons.groupby(["user_id", "entity_b"]).indices.keys())
)
)
dtf = pd.DataFrame([*user_entities_pairs], columns=["user_id", "entity_id"])
dtf["is_public"] = True
dtf["trust_score"] = dtf["user_id"].map(self.users["trust_score"])
scaling_calibration_user_ids = (
dtf[dtf.trust_score > self.SCALING_CALIBRATION_MIN_TRUST_SCORE]["user_id"]
.value_counts(sort=True)[: self.MAX_SCALING_CALIBRATION_USERS]
.index
)
dtf["is_scaling_calibration_user"] = dtf["user_id"].isin(scaling_calibration_user_ids)
return dtf

def get_individual_scores(
self,
criteria: Optional[str] = None,
user_id: Optional[int] = None,
) -> Optional[pd.DataFrame]:
raise NotImplementedError

0 comments on commit ba95325

Please sign in to comment.