diff --git a/github/actions/client.go b/github/actions/client.go index c31eefc81b..1afbab9c42 100644 --- a/github/actions/client.go +++ b/github/actions/client.go @@ -975,20 +975,38 @@ func (c *Client) getActionsServiceAdminConnection(ctx context.Context, rt *regis c.logger.Info("getting Actions tenant URL and JWT", "registrationURL", req.URL.String()) - resp, err := c.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() + var resp *http.Response + retry := 0 + for { + var err error + resp, err = c.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode > 299 { - registrationErr := fmt.Errorf("unexpected response from Actions service during registration call: %v", resp.StatusCode) + if resp.StatusCode >= 200 && resp.StatusCode <= 299 { + break + } + errStr := fmt.Sprintf("unexpected response from Actions service during registration call: %v", resp.StatusCode) body, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("%v - %v", registrationErr, err) + err = fmt.Errorf("%s - %w", errStr, err) + } else { + err = fmt.Errorf("%s - %v", errStr, string(body)) + } + + if resp.StatusCode != http.StatusUnauthorized && resp.StatusCode != http.StatusForbidden { + return nil, err + } + + retry++ + if retry > 3 { + return nil, fmt.Errorf("unable to register runner after 3 retries: %v", err) } - return nil, fmt.Errorf("%v - %v", registrationErr, string(body)) + time.Sleep(time.Duration(500 * int(time.Millisecond) * (retry + 1))) + } var actionsServiceAdminConnection *ActionsServiceAdminConnection diff --git a/github/actions/github_api_request_test.go b/github/actions/github_api_request_test.go index 2c744cfdc9..78f740ecf1 100644 --- a/github/actions/github_api_request_test.go +++ b/github/actions/github_api_request_test.go @@ -2,6 +2,7 @@ package actions_test import ( "context" + "encoding/json" "io" "net/http" "net/url" @@ -152,6 +153,43 @@ func TestNewActionsServiceRequest(t *testing.T) { assert.Equal(t, client.ActionsServiceAdminTokenExpiresAt, expiresAt) }) + t.Run("admin token refresh retry", func(t *testing.T) { + newToken := defaultActionsToken(t) + errMessage := `{"message":"test"}` + + srv := "http://github.com/my-org" + resp := &actions.ActionsServiceAdminConnection{ + AdminToken: &newToken, + ActionsServiceUrl: &srv, + } + failures := 0 + unauthorizedHandler := func(w http.ResponseWriter, r *http.Request) { + if failures < 2 { + failures++ + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(errMessage)) + return + } + + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(resp) + } + server := testserver.New(t, nil, testserver.WithActionsToken("random-token"), testserver.WithActionsToken(newToken), testserver.WithActionsRegistrationTokenHandler(unauthorizedHandler)) + client, err := actions.NewClient(server.ConfigURLForOrg("my-org"), defaultCreds) + require.NoError(t, err) + expiringToken := "expiring-token" + expiresAt := time.Now().Add(59 * time.Second) + client.ActionsServiceAdminToken = expiringToken + client.ActionsServiceAdminTokenExpiresAt = expiresAt + + _, err = client.NewActionsServiceRequest(ctx, http.MethodGet, "my-path", nil) + require.NoError(t, err) + assert.Equal(t, client.ActionsServiceAdminToken, newToken) + assert.Equal(t, client.ActionsServiceURL, srv) + assert.NotEqual(t, client.ActionsServiceAdminTokenExpiresAt, expiresAt) + }) + t.Run("token is currently valid", func(t *testing.T) { tokenThatShouldNotBeFetched := defaultActionsToken(t) server := testserver.New(t, nil, testserver.WithActionsToken(tokenThatShouldNotBeFetched))