Skip to content

Commit

Permalink
Merge pull request #87 from epfLLM/evalonly_and_wbresume
Browse files Browse the repository at this point in the history
Added eval-only and wandb resume options
  • Loading branch information
AleHD authored Nov 6, 2023
2 parents 402f7e8 + 47e6f9c commit 820f102
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 2 deletions.
8 changes: 6 additions & 2 deletions megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,10 +536,12 @@ def _add_logging_args(parser):
help='Project name for Weights & Biases.')
group.add_argument('--wandb_entity', type=str, default="meditron",
help='Entity/team name for Weights & Biases.')
+ group.add_argument('--wandb_name', type=str, default=None,
+ help='Name for this run, alternatively can set `WANDB_NAME`.')
group.add_argument('--wandb_id',type=str,default=None,
help="Unique ID to identify this run, alternatively can set `WANDB_RUN_ID`.")
- group.add_argument('--wandb_resume',action="store_true",
- help="If set, we resume logging for the id given instead of launching a new run (errors if id given and resume=False).")
+ group.add_argument('--wandb_resume',type=str,default="allow",
+ help="If set, we resume logging for the id given instead of launching a new run (errors if id given and resume=None).")
group.add_argument("--wandb_api_key",type=str,default=None,
help="API key for Weights & Biases, needs to be set if not set in environment variable `WANDB_API_KEY`.")
group.add_argument("--metrics", default=[], nargs="+", choices=list(METRICS) + ["all"],
Expand Down Expand Up @@ -878,6 +880,8 @@ def _add_distributed_args(parser):

def _add_validation_args(parser):
group = parser.add_argument_group(title='validation')
+ group.add_argument('--eval_only', action='store_true',
+ help='Run evaluation only.')
group.add_argument('--eval_iters', type=int, default=100,
help='Number of iterations to run for evaluation'
'validation/test for.')
Expand Down
5 changes: 5 additions & 0 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,11 @@ def build_train_valid_test_data_iterators(build_train_valid_test_datasets_provid
args.do_valid = flags[1].item()
args.do_test = flags[2].item()

+ if args.eval_only:
+ args.do_train = False
+ args.do_valid = False
+ args.do_test = True

# Build iterators.
dl_type = args.dataloader_type
assert dl_type in ['single', 'cyclic']
Expand Down
1 change: 1 addition & 0 deletions megatron/wandb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def from_args(args)->'WandBConfig':
log_interval=args.log_interval,
config=args,entity=args.wandb_entity,
project=args.wandb_project,
+ name=args.wandb_name,
run_id=args.wandb_id,
resume=args.wandb_resume,
api_key=args.wandb_api_key,
Expand Down

0 comments on commit 820f102

Please sign in to comment.