Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(sdk): consolidate http calls in one function #840

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 38 additions & 7 deletions pkg/internal/httprequest/httprequest.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"fmt"
"io"
"net/http"
"slices"
"time"

"github.com/trustbloc/wallet-sdk/pkg/api"
Expand Down Expand Up @@ -42,11 +43,28 @@ func New(httpClient httpClient, metricsLogger api.MetricsLogger) *Request {
func (r *Request) Do(method, endpointURL, contentType string, body io.Reader,
event, parentEvent string, errorResponseHandler func(statusCode int, responseBody []byte) error,
) ([]byte, error) {
req, err := http.NewRequestWithContext(context.Background(), method, endpointURL, body)
return r.DoContext(context.Background(), method, endpointURL, contentType,
nil, body, event, parentEvent, nil, errorResponseHandler)
}

var defaultAcceptableStatuses = []int{http.StatusOK}

// DoContext is the same as Do, but also accept context and headers.
func (r *Request) DoContext(ctx context.Context, method, endpointURL, contentType string,
additionalHeaders http.Header, body io.Reader, event, parentEvent string, acceptableStatuses []int,
errorResponseHandler func(statusCode int, responseBody []byte) error,
) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, method, endpointURL, body)
if err != nil {
return nil, err
}

for header, values := range additionalHeaders {
for _, value := range values {
req.Header.Add(header, value)
}
}

if contentType != "" {
req.Header.Add("Content-Type", contentType)
}
Expand Down Expand Up @@ -79,9 +97,14 @@ func (r *Request) Do(method, endpointURL, contentType string, body io.Reader,
return nil, err
}

if resp.StatusCode != http.StatusOK {
statuses := acceptableStatuses
if statuses == nil {
statuses = defaultAcceptableStatuses
}

if !slices.Contains(statuses, resp.StatusCode) {
if errorResponseHandler == nil {
errorResponseHandler = genericErrorResponseHandler
errorResponseHandler = genericErrorResponseHandler(statuses)
}

return nil, errorResponseHandler(resp.StatusCode, respBytes)
Expand All @@ -106,8 +129,16 @@ func (r *Request) DoAndParse(method, endpointURL, contentType string, body io.Re
return json.Unmarshal(respBytes, response)
}

func genericErrorResponseHandler(statusCode int, respBytes []byte) error {
return fmt.Errorf(
"expected status code %d but got status code %d with response body %s instead",
http.StatusOK, statusCode, respBytes)
func genericErrorResponseHandler(expectedStatusCodes []int) func(statusCode int, respBytes []byte) error {
return func(statusCode int, respBytes []byte) error {
if len(expectedStatusCodes) == 1 {
return fmt.Errorf(
"expected status code %d but got status code %d with response body %s instead",
expectedStatusCodes[0], statusCode, respBytes)
}

return fmt.Errorf(
"expected status codes %v but got status code %d with response body %s instead",
expectedStatusCodes, statusCode, respBytes)
}
}
47 changes: 15 additions & 32 deletions pkg/oauth2/clientregistration.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,18 @@ package oauth2

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"

"github.com/trustbloc/wallet-sdk/pkg/internal/httprequest"
)

const (
newRegisterClientEventText = "Register client"
fetchRequestObjectEventText = "Fetch request object via an HTTP GET request to %s"
)

// RegisterClient registers a new client at the given registration endpoint.
Expand Down Expand Up @@ -55,39 +62,15 @@ func RegisterClient(registrationEndpoint string, clientMetadata *ClientMetadata,
}

func getRawResponse(requestBytes []byte, registrationEndpoint string, opts *opts) ([]byte, error) {
httpReq, err := http.NewRequest( //nolint: noctx // Timeout expected to be set in HTTP client already
http.MethodPost, registrationEndpoint, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}

httpReq.Header.Set("Content-Type", "application/json")

headers := http.Header{}
if opts.initialAccessBearerToken != "" {
httpReq.Header.Set("Authorization", "Bearer "+opts.initialAccessBearerToken)
}

resp, err := opts.httpClient.Do(httpReq)
if err != nil {
return nil, err
headers.Set("Authorization", "Bearer "+opts.initialAccessBearerToken)
}

defer func() {
errClose := resp.Body.Close()
if errClose != nil {
println(fmt.Sprintf("failed to close response body: %s", errClose.Error()))
}
}()

respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}

if resp.StatusCode != http.StatusCreated {
return nil, fmt.Errorf("server returned status code %d with body [%s]", resp.StatusCode,
string(respBody))
}
metricsEvent := fmt.Sprintf(fetchRequestObjectEventText, registrationEndpoint)

return respBody, nil
return httprequest.New(opts.httpClient, opts.metricsLogger).DoContext(context.TODO(),
http.MethodPost, registrationEndpoint, "application/json", headers,
bytes.NewReader(requestBytes), metricsEvent, newRegisterClientEventText,
[]int{http.StatusCreated}, nil)
}
2 changes: 1 addition & 1 deletion pkg/oauth2/clientregistration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func TestRegisterClient(t *testing.T) {
defer server.Close()

response, err := oauth2.RegisterClient(server.URL, nil)
require.EqualError(t, err, "server returned status code 500 with body []")
require.ErrorContains(t, err, "expected status code 201 but got status code 500 with response body instead")
require.Nil(t, response)
})
t.Run("Server returns empty body, resulting in a JSON unmarshal failure", func(t *testing.T) {
Expand Down
15 changes: 15 additions & 0 deletions pkg/oauth2/opts.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"net/http"

"github.com/trustbloc/wallet-sdk/pkg/api"
"github.com/trustbloc/wallet-sdk/pkg/metricslogger/noop"
)

type opts struct {
initialAccessBearerToken string
httpClient *http.Client
metricsLogger api.MetricsLogger
}

// An Opt is a single option for a call to RegisterClient.
Expand All @@ -29,6 +31,15 @@ func WithHTTPClient(httpClient *http.Client) Opt {
}
}

// WithMetricsLogger is an option for a call to RegisterClient that allows a caller to specify their MetricsLogger.
// If used, then performance metrics events will be pushed to the given MetricsLogger implementation.
// If this option is not used, then metrics logging will be disabled.
func WithMetricsLogger(metricsLogger api.MetricsLogger) Opt {
return func(opts *opts) {
opts.metricsLogger = metricsLogger
}
}

func processOpts(options []Opt) *opts {
opts := mergeOpts(options)

Expand All @@ -48,5 +59,9 @@ func mergeOpts(options []Opt) *opts {
}
}

if resolveOpts.metricsLogger == nil {
resolveOpts.metricsLogger = noop.NewMetricsLogger()
}

return resolveOpts
}
73 changes: 11 additions & 62 deletions pkg/openid4ci/interaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
Expand All @@ -33,6 +32,7 @@ import (
"github.com/trustbloc/wallet-sdk/pkg/common"
diderrors "github.com/trustbloc/wallet-sdk/pkg/did"
"github.com/trustbloc/wallet-sdk/pkg/did/wellknown"
"github.com/trustbloc/wallet-sdk/pkg/internal/httprequest"
metadatafetcher "github.com/trustbloc/wallet-sdk/pkg/internal/issuermetadata"
"github.com/trustbloc/wallet-sdk/pkg/models/issuer"
"github.com/trustbloc/wallet-sdk/pkg/walleterror"
Expand Down Expand Up @@ -380,19 +380,21 @@ func (i *interaction) getCredentialResponse(signer api.JWTSigner, nonce any,
oAuthHTTPClient := createOAuthHTTPClient(i.oAuth2Config, i.authToken, i.httpClient)

for index := range credentialTypes {
request, err := i.createCredentialRequestWithoutAccessToken(proofJWT, credentialFormats[index],
requestBody, err := i.createCredentialRequestBody(proofJWT, credentialFormats[index],
credentialTypes[index], credentialContexts[index])
if err != nil {
return nil, err
}

// The access token header will be injected automatically by the OAuth HTTP client, so there's no need to
// explicitly set it on the request object generated by the method call above.

fetchCredentialResponseEventText := fmt.Sprintf(fetchCredentialViaGETReqEventText, index+1,
len(credentialTypes), i.issuerMetadata.CredentialEndpoint)

responseBytes, err := i.getRawCredentialResponse(request, fetchCredentialResponseEventText, oAuthHTTPClient)
// The access token header will be injected automatically by the OAuth HTTP client, so there's no need to
// explicitly set it on the request object generated by the method call above.
responseBytes, err := httprequest.New(oAuthHTTPClient, i.metricsLogger).DoContext(context.TODO(),
http.MethodPost, i.issuerMetadata.CredentialEndpoint, "application/json", nil,
bytes.NewReader(requestBody), fetchCredentialResponseEventText, requestCredentialEventText,
[]int{http.StatusOK, http.StatusCreated}, processCredentialErrorResponse)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -461,12 +463,9 @@ func createOAuthHTTPClient(
return oAuthHTTPClient
}

// The returned *http.Request will not have the access token set on it. The caller must ensure that it's set
// before sending the request to the server.
func (i *interaction) createCredentialRequestWithoutAccessToken(proofJWT, credentialFormat string,
func (i *interaction) createCredentialRequestBody(proofJWT, credentialFormat string,
credentialTypes, credentialContext []string,
) (*http.Request, error) {

) ([]byte, error) {
var credentialContextToSend *[]string

if len(credentialContext) > 0 {
Expand All @@ -485,57 +484,7 @@ func (i *interaction) createCredentialRequestWithoutAccessToken(proofJWT, creden
},
}

credentialReqBytes, err := json.Marshal(credentialReq)
if err != nil {
return nil, err
}

request, err := http.NewRequest(http.MethodPost, //nolint: noctx
i.issuerMetadata.CredentialEndpoint, bytes.NewReader(credentialReqBytes))
if err != nil {
return nil, err
}

request.Header.Add("Content-Type", "application/json")

return request, nil
}

func (i *interaction) getRawCredentialResponse(credentialReq *http.Request, eventText string, httpClient *http.Client,
) ([]byte, error) {
timeStartHTTPRequest := time.Now()

response, err := httpClient.Do(credentialReq)
if err != nil {
return nil, err
}

err = i.metricsLogger.Log(&api.MetricsEvent{
Event: eventText,
ParentEvent: requestCredentialEventText,
Duration: time.Since(timeStartHTTPRequest),
})
if err != nil {
return nil, err
}

responseBytes, err := io.ReadAll(response.Body)
if err != nil {
return nil, err
}

if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusCreated {
return nil, processCredentialErrorResponse(response.StatusCode, responseBytes)
}

defer func() {
errClose := response.Body.Close()
if errClose != nil {
println(fmt.Sprintf("failed to close response body: %s", errClose.Error()))
}
}()

return responseBytes, nil
return json.Marshal(credentialReq)
}

func (i *interaction) getVCsFromCredentialResponses(
Expand Down
29 changes: 13 additions & 16 deletions pkg/openid4ci/issuerinitiatedinteraction.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,19 +443,22 @@ func (i *IssuerInitiatedInteraction) getCredentialResponse(
credentialResponses := make([]CredentialResponse, len(i.credentialTypes))

for index := range i.credentialTypes {
request, err := i.interaction.createCredentialRequestWithoutAccessToken(proofJWT, i.credentialFormats[index],
requestBody, err := i.interaction.createCredentialRequestBody(proofJWT, i.credentialFormats[index],
i.credentialTypes[index], i.credentialContexts[index])
if err != nil {
return nil, err
}

request.Header.Add("Authorization", "Bearer "+tokenResponse.AccessToken)
headers := http.Header{}
headers.Add("Authorization", "Bearer "+tokenResponse.AccessToken)

fetchCredentialResponseEventText := fmt.Sprintf(fetchCredentialViaGETReqEventText, index+1,
len(i.credentialTypes), i.interaction.issuerMetadata.CredentialEndpoint)

responseBytes, err := i.interaction.getRawCredentialResponse(request, fetchCredentialResponseEventText,
i.interaction.httpClient)
responseBytes, err := httprequest.New(i.interaction.httpClient, i.interaction.metricsLogger).DoContext(context.TODO(),
http.MethodPost, i.interaction.issuerMetadata.CredentialEndpoint, "application/json", headers,
bytes.NewReader(requestBody), fetchCredentialResponseEventText, requestCredentialEventText,
[]int{http.StatusOK, http.StatusCreated}, processCredentialErrorResponse)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -505,22 +508,16 @@ func (i *IssuerInitiatedInteraction) getCredentialResponsesBatch(
return nil, err
}

request, err := http.NewRequestWithContext(context.Background(),
http.MethodPost,
i.interaction.issuerMetadata.BatchCredentialEndpoint,
bytes.NewReader(b),
)
if err != nil {
return nil, err
}

request.Header.Add("Content-Type", "application/json")
request.Header.Add("Authorization", "Bearer "+tokenResponse.AccessToken)
headers := http.Header{}
headers.Add("Authorization", "Bearer "+tokenResponse.AccessToken)

fetchCredentialResponseEventText := fmt.Sprintf(fetchCredentialViaGETReqEventText, numberOfCredentials,
numberOfCredentials, i.interaction.issuerMetadata.BatchCredentialEndpoint)

b, err = i.interaction.getRawCredentialResponse(request, fetchCredentialResponseEventText, i.interaction.httpClient)
b, err = httprequest.New(i.interaction.httpClient, i.interaction.metricsLogger).DoContext(context.TODO(),
http.MethodPost, i.interaction.issuerMetadata.BatchCredentialEndpoint, "application/json", headers,
bytes.NewReader(b), fetchCredentialResponseEventText, requestCredentialEventText,
[]int{http.StatusOK, http.StatusCreated}, processCredentialErrorResponse)
if err != nil {
return nil, err
}
Expand Down
Loading
Loading