Skip to content

Commit

Permalink
fix(collector): Let pgx library parse TLS parameters (#1390)
Browse files Browse the repository at this point in the history
* fix(collector): Let pgx library parse TLS parameters

This allows the collector to respect the sslmode parameters

Fix: #1163

* Add comment

* Improve postgres collector test
  • Loading branch information
banjoh authored Nov 16, 2023
1 parent bc48568 commit d4623d9
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 25 deletions.
10 changes: 6 additions & 4 deletions pkg/collect/cluster_resources.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,19 +120,16 @@ func (c *CollectClusterResources) Collect(progressChan chan<- interface{}) (Coll
var namespaceNames []string
if len(c.Collector.Namespaces) > 0 {
namespaces, namespaceErrors := getNamespaces(ctx, client, c.Collector.Namespaces)
klog.V(4).Infof("checking for namespaces access: %s", string(namespaces))
namespaceNames = c.Collector.Namespaces
output.SaveResult(c.BundlePath, path.Join(constants.CLUSTER_RESOURCES_DIR, fmt.Sprintf("%s.json", constants.CLUSTER_RESOURCES_NAMESPACES)), bytes.NewBuffer(namespaces))
output.SaveResult(c.BundlePath, path.Join(constants.CLUSTER_RESOURCES_DIR, fmt.Sprintf("%s-errors.json", constants.CLUSTER_RESOURCES_NAMESPACES)), marshalErrors(namespaceErrors))
} else if c.Namespace != "" {
namespace, namespaceErrors := getNamespace(ctx, client, c.Namespace)
klog.V(4).Infof("checking for namespace access: %s", string(namespace))
output.SaveResult(c.BundlePath, path.Join(constants.CLUSTER_RESOURCES_DIR, fmt.Sprintf("%s.json", constants.CLUSTER_RESOURCES_NAMESPACES)), bytes.NewBuffer(namespace))
output.SaveResult(c.BundlePath, path.Join(constants.CLUSTER_RESOURCES_DIR, fmt.Sprintf("%s-errors.json", constants.CLUSTER_RESOURCES_NAMESPACES)), marshalErrors(namespaceErrors))
namespaceNames = append(namespaceNames, c.Namespace)
} else {
namespaces, namespaceList, namespaceErrors := getAllNamespaces(ctx, client)
klog.V(4).Infof("checking for all namespaces access: %s", string(namespaces))
output.SaveResult(c.BundlePath, path.Join(constants.CLUSTER_RESOURCES_DIR, fmt.Sprintf("%s.json", constants.CLUSTER_RESOURCES_NAMESPACES)), bytes.NewBuffer(namespaces))
output.SaveResult(c.BundlePath, path.Join(constants.CLUSTER_RESOURCES_DIR, fmt.Sprintf("%s-errors.json", constants.CLUSTER_RESOURCES_NAMESPACES)), marshalErrors(namespaceErrors))
if namespaceList != nil {
Expand All @@ -146,6 +143,7 @@ func (c *CollectClusterResources) Collect(progressChan chan<- interface{}) (Coll
reviewStatuses, reviewStatusErrors := getSelfSubjectRulesReviews(ctx, client, namespaceNames)

// auth cani
klog.V(2).Infof("checking [%s] namespaces for permissions to collect resources", strings.Join(namespaceNames, ", "))
authCanI := authCanI(reviewStatuses, namespaceNames)
for k, v := range authCanI {
output.SaveResult(c.BundlePath, path.Join(constants.CLUSTER_RESOURCES_DIR, constants.CLUSTER_RESOURCES_AUTH_CANI, k), bytes.NewBuffer(v))
Expand All @@ -160,8 +158,12 @@ func (c *CollectClusterResources) Collect(progressChan chan<- interface{}) (Coll
filteredNamespaces = append(filteredNamespaces, ns)
}
}
if len(filteredNamespaces) != len(namespaceNames) {
klog.V(2).Infof("filtered namespaces down to [%s] after evaluating permissions", strings.Join(filteredNamespaces, ", "))
} else {
klog.V(2).Infof("no namespaces filtered out after evaluating permissions")
}
namespaceNames = filteredNamespaces
klog.V(4).Infof("filtered to namespaceNames %s", namespaceNames)
}

// pods
Expand Down
75 changes: 67 additions & 8 deletions pkg/collect/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"regexp"

"github.com/jackc/pgx/v5"
"github.com/pkg/errors"
troubleshootv1beta2 "github.com/replicatedhq/troubleshoot/pkg/apis/troubleshoot/v1beta2"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/klog/v2"
)

type CollectPostgres struct {
Expand All @@ -37,20 +40,74 @@ func (c *CollectPostgres) createConnectConfig() (*pgx.ConnConfig, error) {
return nil, errors.New("postgres uri cannot be empty")
}

cfg, err := pgx.ParseConfig(c.Collector.URI)
if err != nil {
return nil, errors.Wrap(err, "failed to parse postgres config")
}

if c.Collector.TLS != nil {
tlsCfg, err := createTLSConfig(c.Context, c.Client, c.Collector.TLS)
klog.V(2).Infof("Connecting to postgres with TLS client config")
// Set the libpq TLS environment variables since pgx parses them to
// create the TLS configuration (tls.Config instance) to connect with
// https://www.postgresql.org/docs/current/libpq-envars.html
caCert, clientCert, clientKey, err := getTLSParamTriplet(c.Context, c.Client, c.Collector.TLS)
if err != nil {
return nil, err
}

tlsCfg.ServerName = cfg.Host
cfg.TLSConfig = tlsCfg
// Drop the TLS params to files and set the paths to their
// respective environment variables
// The environment variables are unset after the connection config
// is created. Their respective files are deleted as well.
tmpdir, err := os.MkdirTemp("", "ts-postgres-collector")
if err != nil {
return nil, errors.Wrap(err, "failed to create temp dir to store postgres collector TLS files")
}
defer os.RemoveAll(tmpdir)

if caCert != "" {
caCertPath := filepath.Join(tmpdir, "ca.crt")
err = os.WriteFile(caCertPath, []byte(caCert), 0644)
if err != nil {
return nil, errors.Wrap(err, "failed to write ca cert to file")
}
err = os.Setenv("PGSSLROOTCERT", caCertPath)
if err != nil {
return nil, errors.Wrap(err, "failed to set PGSSLROOTCERT environment variable")
}
klog.V(2).Infof("'PGSSLROOTCERT' environment variable set to %q", caCertPath)
defer os.Unsetenv("PGSSLROOTCERT")
}

if clientCert != "" {
clientCertPath := filepath.Join(tmpdir, "client.crt")
err = os.WriteFile(clientCertPath, []byte(clientCert), 0644)
if err != nil {
return nil, errors.Wrap(err, "failed to write client cert to file")
}
err = os.Setenv("PGSSLCERT", clientCertPath)
if err != nil {
return nil, errors.Wrap(err, "failed to set PGSSLCERT environment variable")
}
klog.V(2).Infof("'PGSSLCERT' environment variable set to %q", clientCertPath)
defer os.Unsetenv("PGSSLCERT")
}

if clientKey != "" {
clientKeyPath := filepath.Join(tmpdir, "client.key")
err = os.WriteFile(clientKeyPath, []byte(clientKey), 0600)
if err != nil {
return nil, errors.Wrap(err, "failed to write client key to file")
}
err = os.Setenv("PGSSLKEY", clientKeyPath)
if err != nil {
return nil, errors.Wrap(err, "failed to set PGSSLKEY environment variable")
}
klog.V(2).Infof("'PGSSLKEY' environment variable set to %q", clientKeyPath)
defer os.Unsetenv("PGSSLKEY")
}
}

cfg, err := pgx.ParseConfig(c.Collector.URI)
if err != nil {
return nil, errors.Wrap(err, "failed to parse postgres config")
}
klog.V(2).Infof("Successfully parsed postgres config")

return cfg, nil
}
Expand All @@ -74,8 +131,10 @@ func (c *CollectPostgres) Collect(progressChan chan<- interface{}) (CollectorRes

conn, err := c.connect()
if err != nil {
klog.V(2).Infof("Postgres connection error: %s", err.Error())
databaseConnection.Error = err.Error()
} else {
klog.V(2).Info("Successfully connected to postgres")
defer conn.Close(c.Context)

query := `select version()`
Expand Down
21 changes: 19 additions & 2 deletions pkg/collect/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package collect

import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"testing"

"github.com/replicatedhq/troubleshoot/internal/testutils"
Expand Down Expand Up @@ -100,7 +103,7 @@ func TestCollectPostgres_createConnectConfigTLS(t *testing.T) {
Client: k8sClient,
Context: context.Background(),
Collector: &v1beta2.Database{
URI: "postgresql://user:password@my-pghost:5432/defaultdb?sslmode=require",
URI: "postgresql://user:password@my-pghost:5432/defaultdb?sslmode=verify-full",
TLS: &v1beta2.TLSParams{
CACert: testutils.GetTestFixture(t, "db/ca.pem"),
ClientCert: testutils.GetTestFixture(t, "db/client.pem"),
Expand All @@ -113,7 +116,21 @@ func TestCollectPostgres_createConnectConfigTLS(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, connCfg)
assert.Equal(t, connCfg.Host, "my-pghost")
assert.NotNil(t, connCfg.TLSConfig.Certificates)

// Check client cert
require.Len(t, connCfg.TLSConfig.Certificates, 1)
require.Len(t, connCfg.TLSConfig.Certificates[0].Certificate, 1)
cert := connCfg.TLSConfig.Certificates[0]
clientCert, err := x509.ParseCertificate(cert.Certificate[0])
require.NoError(t, err)
assert.Equal(t, "CN=client,L=Didcot,ST=Oxfordshire,C=UK", clientCert.Subject.String())

// Check client key
block, _ := pem.Decode([]byte(testutils.GetTestFixture(t, "db/client-key.pem")))
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
require.NoError(t, err)
assert.True(t, key.Equal(cert.PrivateKey.(*rsa.PrivateKey)))

assert.NotNil(t, connCfg.TLSConfig.RootCAs)
assert.False(t, connCfg.TLSConfig.InsecureSkipVerify)
}
34 changes: 23 additions & 11 deletions pkg/collect/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,30 @@ func listNodesInSelector(ctx context.Context, client *kubernetes.Clientset, sele

nodes, err := client.CoreV1().Nodes().List(ctx, listOptions)
if err != nil {
return nil, fmt.Errorf("Can't get the list of nodes, got: %w", err)
return nil, fmt.Errorf("can't get the list of nodes, got: %w", err)
}

return nodes.Items, nil
}

func getTLSParamTriplet(
ctx context.Context, client kubernetes.Interface, params *troubleshootv1beta2.TLSParams,
) (string, string, string, error) {
var caCert, clientCert, clientKey string
if params.Secret != nil {
var err error
caCert, clientCert, clientKey, err = getTLSParamsFromSecret(ctx, client, params.Secret)
if err != nil {
return caCert, clientCert, clientKey, err
}
} else {
caCert = params.CACert
clientCert = params.ClientCert
clientKey = params.ClientKey
}
return caCert, clientCert, clientKey, nil
}

func createTLSConfig(ctx context.Context, client kubernetes.Interface, params *troubleshootv1beta2.TLSParams) (*tls.Config, error) {
rootCA, err := x509.SystemCertPool()
if err != nil {
Expand All @@ -158,21 +176,15 @@ func createTLSConfig(ctx context.Context, client kubernetes.Interface, params *t
return tlsCfg, nil
}

var caCert, clientCert, clientKey string
if params.Secret != nil {
caCert, clientCert, clientKey, err = getTLSParamsFromSecret(ctx, client, params.Secret)
if err != nil {
return nil, err
}
} else {
caCert = params.CACert
clientCert = params.ClientCert
clientKey = params.ClientKey
caCert, clientCert, clientKey, err := getTLSParamTriplet(ctx, client, params)
if err != nil {
return nil, err
}

if ok := rootCA.AppendCertsFromPEM([]byte(caCert)); !ok {
return nil, fmt.Errorf("failed to append CA cert to root CA bundle")
}

tlsCfg.RootCAs = rootCA

if clientCert == "" && clientKey == "" {
Expand Down

0 comments on commit d4623d9

Please sign in to comment.