Add linear model statistic with covariate adjustment (#2043)
Co-authored-by: Mike Williams <102263964+mikewilli@users.noreply.github.com>
danielkberry and mikewilli authored May 21, 2024
1 parent 1ba629d commit 7b339d3
Showing 10 changed files with 975 additions and 66 deletions.
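
For readers unfamiliar with the technique: covariate adjustment here means fitting a linear model of the during-experiment metric on branch membership plus a pre-enrollment measurement of (typically) the same metric, which usually shrinks the standard error of the estimated treatment effect compared to a plain difference in means. Below is a minimal illustration of the idea using statsmodels; this is not the `mozanalysis.frequentist_stats.linear_models.compare_branches_lm` call the commit actually uses, and the column and branch names are hypothetical.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

rng = np.random.default_rng(42)
n = 1_000

# Simulated per-client data: a pre-enrollment covariate that is correlated
# with the during-experiment metric, plus a small treatment effect.
pre = rng.gamma(shape=2.0, scale=1.5, size=n)
branch = rng.choice(["control", "treatment"], size=n)
during = 0.8 * pre + 0.3 * (branch == "treatment") + rng.normal(0.0, 1.0, size=n)

df = pd.DataFrame(
    {"branch": branch, "active_hours": during, "active_hours_pre": pre}
)

# Unadjusted: difference in means, expressed as a linear model on branch only.
unadjusted = smf.ols(
    "active_hours ~ C(branch, Treatment(reference='control'))", data=df
).fit()

# Covariate-adjusted: add the pre-enrollment measurement as a regressor.
adjusted = smf.ols(
    "active_hours ~ C(branch, Treatment(reference='control')) + active_hours_pre",
    data=df,
).fit()

# The treatment coefficient's standard error is typically smaller after adjustment.
print(unadjusted.bse.filter(like="[T.treatment]"))
print(adjusted.bse.filter(like="[T.treatment]"))
```

In the commit itself, the model fitting is delegated to `compare_branches_lm` with a `covariate_col_label`; the changes in `analysis.py` below are responsible for joining the right pre-enrollment column onto the metrics data so that column exists.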
135 changes: 125 additions & 10 deletions jetstream/analysis.py
@@ -384,24 +384,64 @@ def subset_metric_table(
self,
metrics_table_name: str,
segment: str,
- metric: Metric,
+ summary: metric.Summary,
analysis_basis: AnalysisBasis,
+ period: AnalysisPeriod,
) -> DataFrame:
"""Pulls the metric data for this segment/analysis basis"""

query = self._create_subset_metric_table_query(
- metrics_table_name, segment, metric, analysis_basis
+ metrics_table_name, segment, summary, analysis_basis, period
)

results = self.bigquery.execute(query).to_dataframe()

return results

- @staticmethod
def _create_subset_metric_table_query(
- metrics_table_name: str, segment: str, metric: Metric, analysis_basis: AnalysisBasis
+ self,
+ metrics_table_name: str,
+ segment: str,
+ summary: metric.Summary,
+ analysis_basis: AnalysisBasis,
+ period: AnalysisPeriod,
) -> str:
query = ""
if covariate_params := summary.statistic.params.get("covariate_adjustment", False):
covariate_metric_name = covariate_params.get("metric", summary.metric.name)
covariate_period = AnalysisPeriod(covariate_params["period"])
if covariate_period != period:
# when we configure a metric, all statistics are applied to all periods
# however, to perform covariate adjustment we must use data from a different
# period. So the metric will be configured with analysis periods like
# [preenrollment_week, weekly, overall] but covariate adjustment should
# only be applied on weekly and overall when using preenrollment_week
# as the covariate.
query = self._create_subset_metric_table_query_covariate(
metrics_table_name,
segment,
summary.metric,
analysis_basis,
covariate_period,
covariate_metric_name,
)

if not query:
query = self._create_subset_metric_table_query_univariate(
metrics_table_name, segment, summary.metric, analysis_basis
)

return query

def _create_subset_metric_table_query_univariate(
self,
metrics_table_name: str,
segment: str,
metric: Metric,
analysis_basis: AnalysisBasis,
) -> str:
"""Creates a SQL query string to pull a single metric for a segment/analysis-"""
"""Creates a SQL query string to pull a single metric for a segment/analysis"""

metric_names = []
# select placeholder column for metrics without select statement
# since metrics that don't appear in the df are skipped
@@ -443,6 +483,81 @@ def _create_subset_metric_table_query

return query

def _create_subset_metric_table_query_covariate(
self,
metrics_table_name: str,
segment: str,
metric: Metric,
analysis_basis: AnalysisBasis,
covariate_period: AnalysisPeriod,
covariate_metric_name: str,
) -> str:
"""Creates a SQL query string to pull a during-experiment metric and join on a
pre-enrollment covariate for a segment/analysis"""

if metric.depends_on:
raise ValueError(
"metrics with dependencies are not currently supported for covariate adjustment"
)

covariate_table_name = self._table_name(
covariate_period.value, 1, analysis_basis=AnalysisBasis.ENROLLMENTS
)

if not self.bigquery.table_exists(covariate_table_name):
logger.error(
(
f"Covariate adjustment table {covariate_table_name} does not exist, "
"falling back to unadjusted inferences"
)
)
return self._create_subset_metric_table_query_univariate(
metrics_table_name, segment, metric, analysis_basis
)

preenrollment_metric_select = f"pre.{covariate_metric_name} AS {covariate_metric_name}_pre"
from_expression = dedent(
f"""{metrics_table_name} during
LEFT JOIN {covariate_table_name} pre
USING (client_id, branch)"""
)

query = dedent(
f"""
SELECT
during.branch,
during.{metric.name},
{preenrollment_metric_select}
FROM (
{from_expression}
)
WHERE during.{metric.name} IS NOT NULL AND
"""
)

if analysis_basis == AnalysisBasis.ENROLLMENTS:
basis_filter = """during.enrollment_date IS NOT NULL"""
elif analysis_basis == AnalysisBasis.EXPOSURES:
basis_filter = (
"""during.enrollment_date IS NOT NULL AND during.exposure_date IS NOT NULL"""
)
else:
raise ValueError(
f"AnalysisBasis {analysis_basis} not valid"
+ f"Allowed values are: {[AnalysisBasis.ENROLLMENTS, AnalysisBasis.EXPOSURES]}"
)

query += basis_filter

if segment != "all":
segment_filter = dedent(
f"""
AND during.{segment} = TRUE"""
)
query += segment_filter

return query

def check_runnable(self, current_date: Optional[datetime] = None) -> bool:
if self.config.experiment.normandy_slug is None:
# some experiments do not have a normandy slug
@@ -746,15 +861,15 @@ def run(self, current_date: datetime, dry_run: bool = False) -> None:

segment_labels = ["all"] + [s.name for s in self.config.experiment.segments]
for segment in segment_labels:
- for m in self.config.metrics[period]:
+ for summary in self.config.metrics[period]:
if (
- m.metric.analysis_bases != analysis_basis
- and analysis_basis not in m.metric.analysis_bases
+ summary.metric.analysis_bases != analysis_basis
+ and analysis_basis not in summary.metric.analysis_bases
):
continue

segment_data = self.subset_metric_table(
- metrics_table, segment, m.metric, analysis_basis
+ metrics_table, segment, summary, analysis_basis, period
)

analysis_length_dates = 1
@@ -764,7 +879,7 @@ def run(self, current_date: datetime, dry_run: bool = False) -> None:
analysis_length_dates = 7

segment_results.__root__ += self.calculate_statistics(
- m,
+ summary,
segment_data,
segment,
analysis_basis,
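For reference, the covariate path added above joins the during-experiment metrics table to the pre-enrollment table and aliases the covariate column as `<metric>_pre`. Here is a sketch of roughly the query it assembles for the enrollments basis and the `all` segment. Table and metric names are hypothetical; the real ones come from the experiment slug, the analysis period, and the metric configuration.

```python
from textwrap import dedent

metric = "active_hours"
during_table = "test_experiment_enrollments_week_1"
pre_table = "test_experiment_enrollments_preenrollment_week_1"

# Approximate shape of the SQL built by _create_subset_metric_table_query_covariate.
approximate_query = dedent(
    f"""
    SELECT
        during.branch,
        during.{metric},
        pre.{metric} AS {metric}_pre
    FROM (
        {during_table} during
        LEFT JOIN {pre_table} pre
        USING (client_id, branch)
    )
    WHERE during.{metric} IS NOT NULL AND
        during.enrollment_date IS NOT NULL
    """
)
print(approximate_query)
```

If the pre-enrollment table has not been written yet, the new `table_exists` check in `bigquery_client.py` below lets the analysis fall back to the unadjusted query instead of failing.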
10 changes: 10 additions & 0 deletions jetstream/bigquery_client.py
@@ -11,6 +11,7 @@
import numpy as np
import pandas as pd
from google.cloud.bigquery_storage import BigQueryReadClient
from google.cloud.exceptions import NotFound
from metric_config_parser.metric import AnalysisPeriod
from pytz import UTC

@@ -56,6 +57,15 @@ def _current_timestamp_label(self) -> str:
"""Returns the current UTC timestamp as a valid BigQuery label."""
return str(int(time.time()))

def table_exists(self, table_name: str) -> bool:
table_ref = self.client.dataset(self.dataset).table(table_name)
try:
self.client.get_table(table_ref)
except NotFound:
return False

return True

def load_table_from_json(
self, results: Iterable[Dict], table: str, job_config: google.cloud.bigquery.LoadJobConfig
):
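The existence check is what lets the covariate path degrade gracefully. A standalone sketch of the same pattern, mirroring the new `BigQueryClient.table_exists` for use outside the client class; the project, dataset, and table names are hypothetical.

```python
from google.cloud import bigquery
from google.cloud.exceptions import NotFound


def table_exists(client: bigquery.Client, dataset: str, table_name: str) -> bool:
    """Mirror of the new BigQueryClient.table_exists: get_table raises
    NotFound for a missing table, which is translated into a boolean."""
    table_ref = client.dataset(dataset).table(table_name)
    try:
        client.get_table(table_ref)
    except NotFound:
        return False
    return True


client = bigquery.Client(project="my-project")
print(table_exists(client, "my_dataset", "test_experiment_enrollments_preenrollment_week_1"))
```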
55 changes: 55 additions & 0 deletions jetstream/statistics.py
@@ -13,6 +13,7 @@
import mozanalysis.bayesian_stats.bayesian_bootstrap
import mozanalysis.bayesian_stats.binary
import mozanalysis.frequentist_stats.bootstrap
import mozanalysis.frequentist_stats.linear_models
import mozanalysis.metrics
import numpy as np
from google.cloud import bigquery
@@ -422,6 +423,60 @@ def transform(
)


@attr.s(auto_attribs=True)
class LinearModelMean(Statistic):
drop_highest: float = attr.field(default=0.005, validator=attr.validators.instance_of(float))
# currently used keys are "metric" as the name of the metric
# and "period" as the (preenrollment) period to pull from
covariate_adjustment: dict[str, str] | None = attr.field(default=None)

@covariate_adjustment.validator
def check(self, attribute, value):
if value is not None:
covariate_period = parser_metric.AnalysisPeriod(value["period"])
preenrollment_periods = [
parser_metric.AnalysisPeriod.PREENROLLMENT_WEEK,
parser_metric.AnalysisPeriod.PREENROLLMENT_DAYS_28,
]
if covariate_period not in preenrollment_periods:
raise ValueError(
"Covariate adjustment must be done using a pre-treatment analysis "
f"period (one of: {[p.value for p in preenrollment_periods]})"
)

def transform(
self,
df: DataFrame,
metric: str,
reference_branch: str,
experiment: Experiment,
analysis_basis: AnalysisBasis,
segment: str,
) -> StatisticResultCollection:

if self.covariate_adjustment is not None:
covariate_col_label = f"{self.covariate_adjustment.get('metric', metric)}_pre"
else:
covariate_col_label = None

ma_result = mozanalysis.frequentist_stats.linear_models.compare_branches_lm(
df,
col_label=metric,
ref_branch_label=reference_branch,
covariate_col_label=covariate_col_label,
threshold_quantile=1 - self.drop_highest,
alphas=[0.05],
)

return flatten_simple_compare_branches_result(
ma_result=ma_result,
metric_name=metric,
statistic_name="mean_lm",
reference_branch=reference_branch,
ci_width=0.95,
)


@attr.s(auto_attribs=True)
class PerClientDAUImpact(BootstrapMean):
drop_highest: float = 0.0
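A sketch of how the new statistic might be constructed directly; the metric name and parameter values are hypothetical, and in practice Jetstream builds statistics from experiment configuration rather than by hand. Results are reported under the statistic name `mean_lm` with a 95% confidence interval, per the `flatten_simple_compare_branches_result` call above.

```python
from jetstream.statistics import LinearModelMean

# Unadjusted: a (trimmed) linear-model estimate of the branch means.
stat = LinearModelMean(drop_highest=0.005)

# Covariate-adjusted: transform() will look for a "<metric>_pre" column,
# e.g. active_hours_pre, joined in by the covariate query in analysis.py.
adjusted = LinearModelMean(
    covariate_adjustment={"metric": "active_hours", "period": "preenrollment_week"}
)

# The validator rejects covariate periods that are not pre-enrollment periods.
try:
    LinearModelMean(covariate_adjustment={"metric": "active_hours", "period": "week"})
except ValueError as err:
    print(err)
```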
8 changes: 6 additions & 2 deletions jetstream/tests/integration/test_analysis_integration.py
@@ -966,6 +966,8 @@ def test_subset_metric_table(

stat = Statistic(name="bootstrap_mean", params={})

+ summary = Summary(test_active_hours, stat)

config.metrics = {AnalysisPeriod.WEEK: [Summary(test_active_hours, stat)]}

self.analysis_mock_run(monkeypatch, config, static_dataset, temporary_dataset, project_id)
@@ -976,8 +978,9 @@
analysis.subset_metric_table(
"test_experiment_exposures_week_1",
"all",
- test_active_hours,
+ summary,
AnalysisBasis.EXPOSURES,
+ AnalysisPeriod.WEEK,
)
.compute()
.sort_values("branch")
@@ -990,8 +993,9 @@
analysis.subset_metric_table(
"test_experiment_enrollments_week_1",
"all",
- test_active_hours,
+ summary,
AnalysisBasis.ENROLLMENTS,
+ AnalysisPeriod.WEEK,
)
.compute()
.sort_values("branch")
5 changes: 5 additions & 0 deletions jetstream/tests/integration/test_bigquery_client.py
@@ -33,6 +33,11 @@ def test_tables_matching_regex(self, client, temporary_dataset):
assert client.tables_matching_regex("^enrollments_.*$") == ["enrollments_test_experiment"]
assert client.tables_matching_regex("nothing") == []

def test_table_exists(self, client, temporary_dataset):
assert client.table_exists("dummy_table") is False
client.client.create_table(f"{temporary_dataset}.dummy_table")
assert client.table_exists("dummy_table") is True

def test_touch_tables(self, client, temporary_dataset):
client.client.create_table(f"{temporary_dataset}.enrollments_test_experiment")
client.client.create_table(f"{temporary_dataset}.statistics_test_experiment_week_0")