Merge pull request #5 from mimowo/job-examples-jax-improve
Improve Jax ML example to jit model updates
mimowo authored Apr 16, 2023
2 parents e4473ba + 204eb2e commit c4c10df
Showing 1 changed file with 11 additions and 10 deletions.
kubernetes/job-examples/ml_training_jax/main.py (11 additions & 10 deletions)
@@ -40,14 +40,14 @@ def __call__(self, x):
         return x
 
 model = MLP()
+optimizer = optax.adam(learning_rate=0.001)
 
 # Define the training functions
 def cross_entropy_loss(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
     one_hot_labels = jax.nn.one_hot(labels, logits.shape[-1])
     return -jnp.mean(jnp.sum(one_hot_labels * jax.nn.log_softmax(logits), axis=-1))
 
 # Computes the loss for the given model using the provided inputs and labels.
-@jax.jit
 def loss_fn(params: Any, inputs: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
     logits = model.apply({'params': params}, inputs)
     return cross_entropy_loss(logits, labels)
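Note on the dropped @jax.jit: loss_fn is now only called inside update_fun, which is itself compiled (jit plus pmap below), so the inner function is traced and compiled as part of that program anyway. A minimal sketch of this nesting behaviour, using a toy function rather than the example's model:

# Minimal sketch (toy function): a plain Python function called inside a
# jitted function is traced and compiled as part of the enclosing program,
# so an inner @jax.jit adds nothing.
import jax
import jax.numpy as jnp

def inner_loss(x):
    return jnp.sum(x ** 2)  # no @jax.jit needed here

@jax.jit
def outer_step(x):
    return inner_loss(x) * 0.5  # inner_loss is inlined into this compilation

print(outer_step(jnp.ones(4)))  # 2.0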
@@ -57,11 +57,15 @@ def split(arr: jnp.ndarray) -> jnp.ndarray:
     n_local_devices = jax.local_device_count()
     return arr.reshape(n_local_devices, arr.shape[0] // n_local_devices, *arr.shape[1:])
 
-# Compute gradients on the given mini-batch
+# Computes gradients on the given mini-batch, then averages gradients across devices and updates the model parameters
+# jit annotation is optional as the function is only used in the pmap context
 @jax.jit
-def avg_value_and_grad_fun(params: Any, inputs: jnp.ndarray, labels: jnp.ndarray) -> Any:
+def update_fun(params: Any, optimizer_state: Any, inputs: jnp.ndarray, labels: jnp.ndarray) -> Any:
     loss, grads = jax.value_and_grad(loss_fn)(params, inputs, labels)
-    return jax.lax.pmean((loss, grads), axis_name='i')
+    avg_loss, avg_grads = jax.lax.pmean((loss, grads), axis_name='i')
+    updates, optimizer_state = optimizer.update(avg_grads, optimizer_state)
+    params = optax.apply_updates(params, updates)
+    return params, optimizer_state, avg_loss
 
 # Load the shard of the MNIST dataset corresponding to the rank
 def load_mnist_shard(split: str, rank: int, size: int) -> Tuple[tf.data.Dataset, tfds.core.DatasetInfo]:
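The jax.lax.pmean collective is what keeps the replicas in sync: each device computes a per-device loss and gradient, and pmean replaces them with the average over the axis named 'i'. A minimal sketch, runnable on however many local devices are present:

# Minimal sketch: pmean averages per-device values over the named axis,
# so every device ends up holding the same result.
import jax
import jax.numpy as jnp

def device_mean(x):
    return jax.lax.pmean(x, axis_name='i')

n = jax.local_device_count()
per_device = jnp.arange(n, dtype=jnp.float32)  # one value per device
out = jax.pmap(device_mean, axis_name='i')(per_device)
print(out)  # n copies of the mean of 0..n-1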
@@ -87,11 +91,11 @@ def run(rank, world_size):
     params = model.init(init_rng, jnp.ones(input_shape, jnp.float32))['params']
 
     # Initialize the optimizer
-    optimizer = optax.adam(learning_rate=0.001)
     optimizer_state = optimizer.init(params)
 
     # Replicate parameters across the local devices on each host
     replicated_params = jax.tree_util.tree_map(lambda x: jnp.array([x] * jax.local_device_count()), params)
+    replicated_optimizer_state = jax.tree_util.tree_map(lambda x: jnp.array([x] * jax.local_device_count()), optimizer_state)
 
     # Train the model
     for epoch in range(NUM_EPOCHS):
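Replicating the optimizer state uses the same pattern as the parameters: tree_map stacks every leaf of the pytree once per local device along a new leading axis, which is the input layout jax.pmap expects. A small sketch with a hypothetical two-leaf pytree:

# Minimal sketch (hypothetical pytree): stacking each leaf along a new
# leading axis of size local_device_count, the input layout pmap expects.
import jax
import jax.numpy as jnp

tree = {'w': jnp.ones((3, 2)), 'b': jnp.zeros(2)}
n = jax.local_device_count()
replicated = jax.tree_util.tree_map(lambda x: jnp.array([x] * n), tree)
print(replicated['w'].shape)  # (n, 3, 2)
print(replicated['b'].shape)  # (n, 2)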
@@ -102,12 +106,9 @@ def run(rank, world_size):
             split_inputs = split(jnp.array(batch_inputs_np))
             split_labels = split(jnp.array(batch_labels_np))
 
-            avg_loss, avg_grads = jax.pmap(avg_value_and_grad_fun, axis_name="i")(replicated_params, split_inputs, split_labels)
-
-            updates, optimizer_state = optimizer.update(avg_grads, optimizer_state)
-            replicated_params = optax.apply_updates(replicated_params, updates)
+            replicated_params, replicated_optimizer_state, replicated_avg_loss = jax.pmap(update_fun, axis_name="i")(replicated_params, replicated_optimizer_state, split_inputs, split_labels)
 
-            avg_epoch_loss += avg_loss[0]
+            avg_epoch_loss += replicated_avg_loss[0]
             logging.info(f"Rank {rank}: epoch: {epoch+1}/{NUM_EPOCHS}, batch: {batch_num+1}/{len(train_ds_shard_in_batches)}, batch size: {batch_inputs_np.shape[0] * world_size}, batch per host size: {batch_inputs_np.shape[0]}")
 
         avg_epoch_loss /= len(train_ds_shard_in_batches)
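Net effect of the commit: gradient computation, cross-device averaging, and the optimizer step now run as a single compiled, pmapped function, so parameters and optimizer state stay device-resident between iterations instead of round-tripping through the host. A self-contained sketch of the same pattern on a toy scalar model (names and shapes are illustrative, not taken from the example):

# Self-contained sketch of the fused train step: per-device gradients,
# pmean across devices, and the optax update all inside one pmapped call.
import jax
import jax.numpy as jnp
import optax

optimizer = optax.adam(learning_rate=0.001)

def loss_fn(param, xs):
    return jnp.mean((param * xs - 1.0) ** 2)

def update_fun(param, opt_state, xs):
    loss, grads = jax.value_and_grad(loss_fn)(param, xs)
    loss, grads = jax.lax.pmean((loss, grads), axis_name='i')
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(param, updates), opt_state, loss

n = jax.local_device_count()
replicate = lambda tree: jax.tree_util.tree_map(lambda x: jnp.array([x] * n), tree)

param = jnp.zeros(())                       # toy scalar parameter
rep_param = replicate(param)                # leading device axis added
rep_state = replicate(optimizer.init(param))
xs = jnp.ones((n, 4), jnp.float32)          # one shard of 4 examples per device

rep_param, rep_state, rep_loss = jax.pmap(update_fun, axis_name='i')(rep_param, rep_state, xs)
print(rep_loss[0])  # loss is identical on every device after pmean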
