Skip to content

Commit

Permalink
Use new train.report API (#49)
Browse files Browse the repository at this point in the history
We are converging on using train.report throughout the Ray library code base instead of tune.report.

See ray-project/xgboost_ray#292

Signed-off-by: Kai Fricke <kai@anyscale.com>
  • Loading branch information
krfricke authored Aug 24, 2023
1 parent 60a4e41 commit 527f31b
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 47 deletions.
2 changes: 1 addition & 1 deletion lightgbm_ray/examples/simple_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def main(cpus_per_actor, num_actors, num_samples):

# Load the best model checkpoint.
best_bst = lightgbm_ray.tune.load_model(
os.path.join(analysis.best_logdir, "tuned.lgbm")
os.path.join(analysis.best_trial.local_path, "tuned.lgbm")
)

best_bst.save_model("best_model.lgbm")
Expand Down
2 changes: 1 addition & 1 deletion lightgbm_ray/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def _save_internal_checkpoint_callback() -> Callable:
def _callback(env: CallbackEnv) -> None:
if not is_rank_0:
return
if (
if this.checkpoint_frequency > 0 and (
env.iteration == env.end_iteration - 1
or env.iteration % this.checkpoint_frequency == 0
):
Expand Down
9 changes: 7 additions & 2 deletions lightgbm_ray/tests/test_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,13 @@ def testReplaceTuneCheckpoints(self):

replaced = in_dict["callbacks"][0]
self.assertTrue(isinstance(replaced, TuneReportCheckpointCallback))
self.assertSequenceEqual(replaced._report._metrics, ["met"])
self.assertEqual(replaced._checkpoint._filename, "test")

if getattr(replaced, "_report", None):
self.assertSequenceEqual(replaced._report._metrics, ["met"])
self.assertEqual(replaced._checkpoint._filename, "test")
else:
self.assertSequenceEqual(replaced._metrics, ["met"])
self.assertEqual(replaced._filename, "test")

def testEndToEndCheckpointing(self):
ray.init(num_cpus=4)
Expand Down
115 changes: 72 additions & 43 deletions lightgbm_ray/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
import ray
from lightgbm.basic import Booster
from lightgbm.callback import CallbackEnv
from ray.train._internal.session import get_session
from ray.util.annotations import PublicAPI
from xgboost_ray.session import put_queue
from xgboost_ray.util import force_on_current_node

try:
from ray import tune
from ray import train, tune
from ray.tune import is_session_enabled
from ray.tune.integration.lightgbm import (
TuneReportCallback as OrigTuneReportCallback,
Expand Down Expand Up @@ -49,49 +50,68 @@ def is_rank_0(self, val: bool):


if TUNE_INSTALLED:

class TuneReportCallback(_TuneLGBMRank0Mixin, OrigTuneReportCallback):
def __call__(self, env: CallbackEnv) -> None:
if not self.is_rank_0:
return
eval_result = self._get_eval_result(env)
report_dict = self._get_report_dict(eval_result)
put_queue(lambda: tune.report(**report_dict))

class _TuneCheckpointCallback(_TuneLGBMRank0Mixin, _OrigTuneCheckpointCallback):
def __call__(self, env: CallbackEnv) -> None:
if not self.is_rank_0:
return
put_queue(
lambda: self._create_checkpoint(
env.model, env.iteration, self._filename, self._frequency
if not hasattr(train, "report"):

class TuneReportCallback(_TuneLGBMRank0Mixin, OrigTuneReportCallback):
def __call__(self, env: CallbackEnv) -> None:
if not self.is_rank_0:
return
eval_result = self._get_eval_result(env)
report_dict = self._get_report_dict(eval_result)
put_queue(lambda: tune.report(**report_dict))

class _TuneCheckpointCallback(_TuneLGBMRank0Mixin, _OrigTuneCheckpointCallback):
def __call__(self, env: CallbackEnv) -> None:
if not self.is_rank_0:
return
put_queue(
lambda: self._create_checkpoint(
env.model, env.iteration, self._filename, self._frequency
)
)
)

class TuneReportCheckpointCallback(
_TuneLGBMRank0Mixin, OrigTuneReportCheckpointCallback
):
_checkpoint_callback_cls = _TuneCheckpointCallback
_report_callback_cls = TuneReportCallback

@property
def is_rank_0(self) -> bool:
try:
return self._is_rank_0
except AttributeError:
return False

@is_rank_0.setter
def is_rank_0(self, val: bool):
self._is_rank_0 = val
if hasattr(self, "_checkpoint"):
self._checkpoint.is_rank_0 = val
if hasattr(self, "_report"):
self._report.is_rank_0 = val

class TuneReportCheckpointCallback(
_TuneLGBMRank0Mixin, OrigTuneReportCheckpointCallback
):
_checkpoint_callback_cls = _TuneCheckpointCallback
_report_callback_cls = TuneReportCallback

@property
def is_rank_0(self) -> bool:
try:
return self._is_rank_0
except AttributeError:
return False

@is_rank_0.setter
def is_rank_0(self, val: bool):
self._is_rank_0 = val
if hasattr(self, "_checkpoint"):
self._checkpoint.is_rank_0 = val
if hasattr(self, "_report"):
self._report.is_rank_0 = val

else:

class TuneReportCheckpointCallback(
_TuneLGBMRank0Mixin, OrigTuneReportCheckpointCallback
):
def __call__(self, env: CallbackEnv):
if self.is_rank_0:
put_queue(
lambda: super(TuneReportCheckpointCallback, self).__call__(
env=env
)
)

class TuneReportCallback(_TuneLGBMRank0Mixin, OrigTuneReportCallback):
def __call__(self, env: CallbackEnv):
if self.is_rank_0:
put_queue(lambda: super(TuneReportCallback, self).__call__(env=env))


def _try_add_tune_callback(kwargs: Dict):
if TUNE_INSTALLED and is_session_enabled():
if TUNE_INSTALLED and is_session_enabled() or get_session():
callbacks = kwargs.get("callbacks", []) or []
new_callbacks = []
has_tune_callback = False
Expand All @@ -117,10 +137,19 @@ def _try_add_tune_callback(kwargs: Dict):
)
has_tune_callback = True
elif isinstance(cb, OrigTuneReportCheckpointCallback):
if getattr(cb, "_report", None):
orig_metrics = cb._report._metrics
orig_filename = cb._checkpoint._filename
orig_frequency = cb._checkpoint._frequency
else:
orig_metrics = cb._metrics
orig_filename = cb._filename
orig_frequency = cb._frequency

replace_cb = TuneReportCheckpointCallback(
metrics=cb._report._metrics,
filename=cb._checkpoint._filename,
frequency=cb._checkpoint._frequency,
metrics=orig_metrics,
filename=orig_filename,
frequency=orig_frequency,
)
new_callbacks.append(replace_cb)
logging.warning(
Expand Down

0 comments on commit 527f31b

Please sign in to comment.