From 204eb2ebd8ba304810bd18538530ef6eeca9d5c9 Mon Sep 17 00:00:00 2001
From: Michal Wozniak
Date: Sun, 16 Apr 2023 11:37:17 +0200
Subject: [PATCH] Improve Jax ML example to jit model updates

---
 .../job-examples/ml_training_jax/main.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/kubernetes/job-examples/ml_training_jax/main.py b/kubernetes/job-examples/ml_training_jax/main.py
index d981625..22485ee 100644
--- a/kubernetes/job-examples/ml_training_jax/main.py
+++ b/kubernetes/job-examples/ml_training_jax/main.py
@@ -40,6 +40,7 @@ def __call__(self, x):
         return x
 
 model = MLP()
+optimizer = optax.adam(learning_rate=0.001)
 
 # Define the training functions
 def cross_entropy_loss(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
@@ -47,7 +48,6 @@ def cross_entropy_loss(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
     return -jnp.mean(jnp.sum(one_hot_labels * jax.nn.log_softmax(logits), axis=-1))
 
 # Computes the loss for the given model using the provided inputs and labels.
-@jax.jit
 def loss_fn(params: Any, inputs: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
     logits = model.apply({'params': params}, inputs)
     return cross_entropy_loss(logits, labels)
@@ -57,11 +57,15 @@ def split(arr: jnp.ndarray) -> jnp.ndarray:
     n_local_devices = jax.local_device_count()
     return arr.reshape(n_local_devices, arr.shape[0] // n_local_devices, *arr.shape[1:])
 
-# Compute gradients on the given mini-batch
+# Computes gradients on the given mini-batch, then averages gradients across devices and updates the model parameters
+# jit annotation is optional as the function is only used in the pmap context
 @jax.jit
-def avg_value_and_grad_fun(params: Any, inputs: jnp.ndarray, labels: jnp.ndarray) -> Any:
+def update_fun(params: Any, optimizer_state: Any, inputs: jnp.ndarray, labels: jnp.ndarray) -> Any:
     loss, grads = jax.value_and_grad(loss_fn)(params, inputs, labels)
-    return jax.lax.pmean((loss, grads), axis_name='i')
+    avg_loss, avg_grads = jax.lax.pmean((loss, grads), axis_name='i')
+    updates, optimizer_state = optimizer.update(avg_grads, optimizer_state)
+    params = optax.apply_updates(params, updates)
+    return params, optimizer_state, avg_loss
 
 # Load the shard of the MNIST dataset corresponding to the rank
 def load_mnist_shard(split: str, rank: int, size: int) -> Tuple[tf.data.Dataset, tfds.core.DatasetInfo]:
@@ -87,11 +91,11 @@ def run(rank, world_size):
     params = model.init(init_rng, jnp.ones(input_shape, jnp.float32))['params']
 
     # Initialize the optimizer
-    optimizer = optax.adam(learning_rate=0.001)
     optimizer_state = optimizer.init(params)
 
     # Replicate parameters across the local devices on each host
     replicated_params = jax.tree_util.tree_map(lambda x: jnp.array([x] * jax.local_device_count()), params)
+    replicated_optimizer_state = jax.tree_util.tree_map(lambda x: jnp.array([x] * jax.local_device_count()), optimizer_state)
 
     # Train the model
     for epoch in range(NUM_EPOCHS):
@@ -102,12 +106,9 @@ def run(rank, world_size):
             split_inputs = split(jnp.array(batch_inputs_np))
             split_labels = split(jnp.array(batch_labels_np))
 
-            avg_loss, avg_grads = jax.pmap(avg_value_and_grad_fun, axis_name="i")(replicated_params, split_inputs, split_labels)
-
-            updates, optimizer_state = optimizer.update(avg_grads, optimizer_state)
-            replicated_params = optax.apply_updates(replicated_params, updates)
+            replicated_params, replicated_optimizer_state, replicated_avg_loss = jax.pmap(update_fun, axis_name="i")(replicated_params, replicated_optimizer_state, split_inputs, split_labels)
 
-            avg_epoch_loss += avg_loss[0]
+            avg_epoch_loss += replicated_avg_loss[0]
             logging.info(f"Rank {rank}: epoch: {epoch+1}/{NUM_EPOCHS}, batch: {batch_num+1}/{len(train_ds_shard_in_batches)}, batch size: {batch_inputs_np.shape[0] * world_size}, batch per host size: {batch_inputs_np.shape[0]}")
         avg_epoch_loss /= len(train_ds_shard_in_batches)
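
Below is a minimal, self-contained sketch of the pattern this patch applies: the gradient averaging (jax.lax.pmean) and the optax update are folded into a single per-device update function that is run under jax.pmap, with both the parameters and the optimizer state replicated across local devices. The model, loss, data, and helper names here (toy_loss_fn, replicate, the linear model, the synthetic shapes) are illustrative stand-ins and are not taken from the example's MLP/MNIST code; only the jax and optax calls mirror the ones used in the patch.

import jax
import jax.numpy as jnp
import optax

optimizer = optax.adam(learning_rate=0.001)

def toy_loss_fn(params, inputs, labels):
    # A toy linear model with a squared-error loss, standing in for model.apply.
    preds = inputs @ params['w'] + params['b']
    return jnp.mean((preds - labels) ** 2)

@jax.jit  # optional here, as in the patch: pmap already compiles the function
def update_fun(params, optimizer_state, inputs, labels):
    loss, grads = jax.value_and_grad(toy_loss_fn)(params, inputs, labels)
    # Average loss and gradients across all devices participating in the pmap.
    avg_loss, avg_grads = jax.lax.pmean((loss, grads), axis_name='i')
    # Apply the optimizer step on-device, inside the compiled function.
    updates, optimizer_state = optimizer.update(avg_grads, optimizer_state)
    params = optax.apply_updates(params, updates)
    return params, optimizer_state, avg_loss

def replicate(tree):
    # Stack one copy of every leaf per local device, as the patch does for
    # both the parameters and the optimizer state.
    n = jax.local_device_count()
    return jax.tree_util.tree_map(lambda x: jnp.array([x] * n), tree)

if __name__ == '__main__':
    n = jax.local_device_count()
    params = {'w': jnp.zeros((4, 1)), 'b': jnp.zeros((1,))}
    optimizer_state = optimizer.init(params)
    replicated_params = replicate(params)
    replicated_optimizer_state = replicate(optimizer_state)

    # One synthetic batch, split so the leading axis matches the device count.
    inputs = jnp.ones((n, 8, 4))
    labels = jnp.zeros((n, 8, 1))

    replicated_params, replicated_optimizer_state, replicated_avg_loss = jax.pmap(
        update_fun, axis_name='i')(replicated_params, replicated_optimizer_state,
                                   inputs, labels)
    print('loss:', replicated_avg_loss[0])

Keeping the update inside the pmapped function means the averaged gradients never have to be pulled back to the host each batch; parameters and optimizer state stay replicated on the devices and the whole training step is compiled as one unit, which is what the commit title refers to as jitting the model updates.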