diff --git a/decode_hooks.go b/decode_hooks.go index 2523c6a..1f3c69d 100644 --- a/decode_hooks.go +++ b/decode_hooks.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/netip" + "net/url" "reflect" "strconv" "strings" @@ -176,6 +177,26 @@ func StringToTimeDurationHookFunc() DecodeHookFunc { } } +// StringToURLHookFunc returns a DecodeHookFunc that converts +// strings to *url.URL. +func StringToURLHookFunc() DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}, + ) (interface{}, error) { + if f.Kind() != reflect.String { + return data, nil + } + if t != reflect.TypeOf(&url.URL{}) { + return data, nil + } + + // Convert it by parsing + return url.Parse(data.(string)) + } +} + // StringToIPHookFunc returns a DecodeHookFunc that converts // strings to net.IP func StringToIPHookFunc() DecodeHookFunc { diff --git a/decode_hooks_test.go b/decode_hooks_test.go index 0604ddd..6549ae0 100644 --- a/decode_hooks_test.go +++ b/decode_hooks_test.go @@ -6,6 +6,7 @@ import ( "math/big" "net" "net/netip" + "net/url" "reflect" "strings" "testing" @@ -286,6 +287,35 @@ func TestStringToTimeDurationHookFunc(t *testing.T) { } } +func TestStringToURLHookFunc(t *testing.T) { + f := StringToURLHookFunc() + + urlSample, _ := url.Parse("http://example.com") + urlValue := reflect.ValueOf(urlSample) + strValue := reflect.ValueOf("http://example.com") + cases := []struct { + f, t reflect.Value + result interface{} + err bool + }{ + {reflect.ValueOf("http://example.com"), urlValue, urlSample, false}, + {reflect.ValueOf("http ://example.com"), urlValue, (*url.URL)(nil), true}, + {reflect.ValueOf("http://example.com"), strValue, "http://example.com", false}, + } + + for i, tc := range cases { + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected err %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +} + func TestStringToTimeHookFunc(t *testing.T) { strValue := reflect.ValueOf("5") timeValue := reflect.ValueOf(time.Time{})