Skip to content

Commit

Permalink
Add http3 CreateTCPStream test
Browse files Browse the repository at this point in the history
This includes a little bit of refactoring the docker container
initialization and making a testutils.NetConWrapper for re-use.
  • Loading branch information
max-b committed Apr 2, 2024
1 parent e463389 commit 150ecdf
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 18 deletions.
145 changes: 127 additions & 18 deletions http3/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@ package http3

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"log"
"log/slog"
"net"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
Expand All @@ -15,13 +21,17 @@ import (
"github.com/stretchr/testify/require"
tc "github.com/testcontainers/testcontainers-go/modules/compose"
"github.com/testcontainers/testcontainers-go/wait"
"golang.org/x/net/http2"
)

const h2oServiceName string = "h2o"

func TestSimpleClientRequest(t *testing.T) {
var logger *slog.Logger
var containerGateway string

func TestMain(m *testing.M) {
level := slog.LevelDebug
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
Level: level,
}))
slog.SetDefault(logger)
Expand All @@ -30,19 +40,22 @@ func TestSimpleClientRequest(t *testing.T) {
identifier := tc.StackIdentifier("h2o_test")
composeFile := fmt.Sprintf("%s/docker-compose.yml", testutils.RootDir())
compose, err := tc.NewDockerComposeWith(tc.WithStackFiles(composeFile), identifier)
require.NoError(t, err, "NewDockerComposeAPIWith()")
if err != nil {
log.Fatalf("error in NewDockerComposeAPIWith: %v", err)
}

t.Cleanup(func() {
require.NoError(t,
compose.Down(
context.Background(),
tc.RemoveOrphans(true),
tc.RemoveImagesLocal,
), "compose.Down()")
})
defer func() {
if err := compose.Down(
context.Background(),
tc.RemoveOrphans(true),
tc.RemoveImagesLocal,
); err != nil {
log.Fatalf("error in compose.Down: %v", err)
}
}()

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
defer cancel()

stack := compose.WaitForService(h2oServiceName,
// The h2o conf provides a /status endpoint listening on
Expand All @@ -53,26 +66,34 @@ func TestSimpleClientRequest(t *testing.T) {
WithStartupTimeout(10*time.Second),
)

err = stack.Up(ctx, tc.Wait(true))

require.NoError(t, err, "compose.Up()")
if err := stack.Up(ctx, tc.Wait(true)); err != nil {
log.Fatalf("error in compose.Up(): %v", err)
}

container, err := stack.ServiceContainer(ctx, h2oServiceName)
require.NoError(t, err, "fetch ServiceContainer")
if err != nil {
log.Fatalf("error in stack.ServiceContainer: %v", err)
}

logger.Info("compose up", "services", stack.Services(), "container", container)

// Kind of awkward network info parsing here.
// We need the container's gateway IP because that _should_ be the address the host can ListenUDP on where the container can access it.
containerIPs, err := container.ContainerIPs(ctx)
require.NoError(t, err, "container.ContainerIPs")
if err != nil {
log.Fatalf("error in container.ContainerIPs: %v", err)
}

containerIP := containerIPs[0]
containerIPSplit := strings.Split(containerIP, ".")
containerNet := strings.Join(containerIPSplit[:len(containerIPSplit)-1], ".")

containerGateway := fmt.Sprintf("%v.1", containerNet)
containerGateway = fmt.Sprintf("%v.1", containerNet)

m.Run()
}

func TestCreateUDPStream(t *testing.T) {
expectedRequest := "test udp request data\n"
expectedResponse := "test udp response data\n"

Expand Down Expand Up @@ -133,3 +154,91 @@ func TestSimpleClientRequest(t *testing.T) {
assert.Equal(t, expectedResponse, string(buf[0:n]), "Should receive correct UDP response")
logger.Info("got response", "buf", buf)
}

func TestCreateTCPStream(t *testing.T) {
// Start target HTTP/S server
expectedResponse := "test http response data"
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "", r.Header.Get("Proxy-Authorization"), "Should not pass the auth token to the end target server")
assert.Equal(t, http.MethodGet, r.Method, "Request to the end target server should be a GET")
fmt.Fprintf(w, expectedResponse)
}))
ts.EnableHTTP2 = true
// We want to listen on 0.0.0.0 because the proxy container will be on a different non-localhost network.
// In order to do that we have this kind of awkward hack borrowed from:
// https://stackoverflow.com/a/42218765/1787596
l, err := net.Listen("tcp", "0.0.0.0:0")
if err != nil {
log.Fatal(err)
}

// Swap out the default test server listener with our custom one listening on 0.0.0.0
ts.Listener.Close()
ts.Listener = l
ts.StartTLS()
defer ts.Close()

log.Printf("Test server listening on: %v", ts.URL)

urlSplit := strings.Split(ts.URL, ":")
port := urlSplit[len(urlSplit)-1]

certDataFile := fmt.Sprintf("%s/testdata/h2o/server.crt", testutils.RootDir())
certData, err := os.ReadFile(certDataFile)
require.NoError(t, err, "Reading certData")

config := ClientConfig{
ProxyAddr: "localhost:8444",
// The h2o server we're using doesn't require an actual token so this can be anything
AuthToken: "fake-token",
Logger: logger,
CertData: certData,
Insecure: true,
}

c, err := NewClient(config)
require.NoError(t, err, "NewClient")
defer c.Close()

// host.docker.internal is a docker specific host mapping for the h2o container that resolves to our localhost
dockerHostURL := fmt.Sprintf("%v:%v", containerGateway, port)
conn, err := c.CreateTCPStream(dockerHostURL)
require.NoError(t, err, "CreateTCPStream")
defer conn.Close()

req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://%v", dockerHostURL), nil)
require.NoError(t, err, "http.NewRequest")

certpool := x509.NewCertPool()
certpool.AddCert(ts.Certificate())
tlsDialWrapper := func(network, addr string, cfg *tls.Config) (net.Conn, error) {
tlsConf := &tls.Config{
RootCAs: certpool,
// It seems as though the httptest TLS server uses this arbitrarily as it's ServerName 🤷
ServerName: "example.com",
NextProtos: ts.TLS.NextProtos,
}
tlsClient := tls.Client(&testutils.NetConnWrapper{ReadWriteCloser: conn}, tlsConf)
err = tlsClient.Handshake()
return tlsClient, err
}

transport := &http2.Transport{
DialTLS: tlsDialWrapper,
}

httpClient := http.Client{
Transport: transport,
}
response, err := httpClient.Do(req)
require.NoError(t, err, "httpClient.Do")

defer response.Body.Close()
data, err := io.ReadAll(response.Body)
require.NoError(t, err, "io.ReadAll response body")

log.Printf("got response: %v", response)

assert.Equal(t, 200, response.StatusCode, "Should receive 200 response")
assert.Equal(t, expectedResponse, string(data), "Should receive expected body")
}
29 changes: 29 additions & 0 deletions internal/testutils/testutils.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package testutils

import (
"io"
"log"
"net"
"path"
"path/filepath"
"runtime"
"time"
)

func RootDir() string {
Expand All @@ -13,3 +16,29 @@ func RootDir() string {
d := path.Join(path.Dir(path.Dir(b)))
return filepath.Dir(d)
}

// We want an interface that can implement net.Conn so we need to add these methods
// But we do not expect them to be called during our tests
type NetConnWrapper struct {
io.ReadWriteCloser
}

func (r *NetConnWrapper) LocalAddr() net.Addr {
return nil
}

func (r *NetConnWrapper) RemoteAddr() net.Addr {
return nil
}

func (r *NetConnWrapper) SetDeadline(t time.Time) error {
return nil
}

func (r *NetConnWrapper) SetReadDeadline(t time.Time) error {
return nil
}

func (r *NetConnWrapper) SetWriteDeadline(t time.Time) error {
return nil
}

0 comments on commit 150ecdf

Please sign in to comment.