diff --git a/app.go b/app.go
index 99fd2976bdf95df0de95ed81e1addbe13d54091a..db9344f05b6de86594384a07036946bca67ed04f 100644
--- a/app.go
+++ b/app.go
@@ -68,7 +68,15 @@ func (a *theApp) ServeTLS(ch *cryptotls.ClientHelloInfo) (*cryptotls.Certificate
 	}
 
 	if domain, _ := a.domain(context.Background(), ch.ServerName); domain != nil {
-		tls, _ := domain.EnsureCertificate()
+		tls, err := domain.EnsureCertificate()
+		if err != nil {
+			log.WithFields(log.Fields{
+				"pages_host":  domain.Name,
+				"local_addr":  ch.Conn.LocalAddr().String(),
+				"remote_addr": ch.Conn.RemoteAddr().String(),
+			}).WithError(err).Warn("failed to load certificate for custom domain")
+		}
+
 		return tls, nil
 	}
 
diff --git a/internal/config/tls/tls.go b/internal/config/tls/tls.go
--- a/internal/config/tls/tls.go
+++ b/internal/config/tls/tls.go
@@ -2,11 +2,17 @@ package tls
 
 import (
 	"crypto/tls"
+	"crypto/x509"
+	"encoding/pem"
+	"errors"
 	"fmt"
 	"sort"
 	"strings"
 )
 
+// ErrEmptyCert is returned when the certificate bytes contain no PEM CERTIFICATE block
+var ErrEmptyCert = errors.New("decode PEM certificate: no CERTIFICATE block found")
+
 // GetCertificateFunc returns the certificate to be used for given domain
 type GetCertificateFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error)
 
@@ -97,3 +103,68 @@ func configureTLSCiphers(tlsConfig *tls.Config) {
 	tlsConfig.PreferServerCipherSuites = true
 	tlsConfig.CipherSuites = preferredCipherSuites
 }
+
+// VerifyCert parses a PEM encoded certificate chain and verifies it against
+// the system roots: the leaf must be valid for domainName and no certificate
+// in the chain may be expired. The first CERTIFICATE block is the leaf; any
+// further CERTIFICATE blocks are used as intermediates and other PEM blocks
+// (e.g. a bundled private key) are ignored. An empty certificate is not an
+// error. ErrEmptyCert is returned when no CERTIFICATE block is found.
+func VerifyCert(domainName string, certificate []byte) error {
+	return verifyCert(domainName, certificate, nil)
+}
+
+// verifyCert implements VerifyCert against the given roots; nil roots means
+// the system pool. Tests use it to trust self-signed certificates.
+func verifyCert(domainName string, certificate []byte, roots *x509.CertPool) error {
+	if len(certificate) == 0 {
+		return nil
+	}
+
+	certs, err := parseCertificates(certificate)
+	if err != nil {
+		return err
+	}
+
+	intermediates := x509.NewCertPool()
+	for _, c := range certs[1:] {
+		intermediates.AddCert(c)
+	}
+
+	opts := x509.VerifyOptions{
+		DNSName:       domainName,
+		Roots:         roots,
+		Intermediates: intermediates,
+	}
+
+	if _, err := certs[0].Verify(opts); err != nil {
+		return fmt.Errorf("verify certificate for %q: %w", domainName, err)
+	}
+
+	return nil
+}
+
+// parseCertificates decodes every PEM CERTIFICATE block in data, in order,
+// skipping blocks of any other type.
+func parseCertificates(data []byte) ([]*x509.Certificate, error) {
+	var certs []*x509.Certificate
+
+	for block, rest := pem.Decode(data); block != nil; block, rest = pem.Decode(rest) {
+		if block.Type != "CERTIFICATE" {
+			continue
+		}
+
+		cert, err := x509.ParseCertificate(block.Bytes)
+		if err != nil {
+			return nil, fmt.Errorf("parse certificate: %w", err)
+		}
+
+		certs = append(certs, cert)
+	}
+
+	if len(certs) == 0 {
+		return nil, ErrEmptyCert
+	}
+
+	return certs, nil
+}
diff --git a/internal/config/tls/tls_test.go b/internal/config/tls/tls_test.go
--- a/internal/config/tls/tls_test.go
+++ b/internal/config/tls/tls_test.go
@@ -1,8 +1,16 @@
 package tls
 
 import (
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/rand"
 	"crypto/tls"
+	"crypto/x509"
+	"crypto/x509/pkix"
+	"encoding/pem"
+	"math/big"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/require"
 )
@@ -69,3 +77,82 @@ func TestCreate(t *testing.T) {
 	require.Equal(t, uint16(tls.VersionTLS11), tlsConfig.MinVersion)
 	require.Equal(t, uint16(tls.VersionTLS12), tlsConfig.MaxVersion)
 }
+
+func TestVerifyCert(t *testing.T) {
+	validCert, validRoots := genTestCert(t, "gitlab.io", time.Hour)
+	expiredCert, expiredRoots := genTestCert(t, "gitlab.io", -time.Minute)
+
+	tests := map[string]struct {
+		domainName     string
+		certBytes      []byte
+		roots          *x509.CertPool
+		expectedErrMsg string
+	}{
+		"empty_cert_bytes_no_error": {domainName: "gitlab.io", certBytes: nil},
+		"invalid_cert_bytes":        {domainName: "gitlab.io", certBytes: []byte(`not PEM bytes`), expectedErrMsg: ErrEmptyCert.Error()},
+		"valid_cert":                {domainName: "gitlab.io", certBytes: validCert, roots: validRoots},
+		"wrong_domain":              {domainName: "example.com", certBytes: validCert, roots: validRoots, expectedErrMsg: "certificate is valid for gitlab.io"},
+		"expired_cert":              {domainName: "gitlab.io", certBytes: expiredCert, roots: expiredRoots, expectedErrMsg: "expired"},
+	}
+
+	for name, test := range tests {
+		t.Run(name, func(t *testing.T) {
+			// verifyCert is used so the self-signed test certificates can be
+			// trusted without depending on the host's system roots.
+			err := verifyCert(test.domainName, test.certBytes, test.roots)
+			if test.expectedErrMsg != "" {
+				require.Error(t, err)
+				require.Contains(t, err.Error(), test.expectedErrMsg)
+				return
+			}
+
+			require.NoError(t, err)
+		})
+	}
+}
+
+// genTestCert returns a self-signed certificate for commonName, valid from
+// one hour ago until now+expiry (a negative expiry yields an expired
+// certificate), PEM encoded and followed by its PEM encoded private key,
+// together with a root pool that trusts the certificate.
+func genTestCert(t *testing.T, commonName string, expiry time.Duration) ([]byte, *x509.CertPool) {
+	t.Helper()
+
+	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+	require.NoError(t, err)
+
+	now := time.Now()
+	template := x509.Certificate{
+		SerialNumber: big.NewInt(1),
+		Subject: pkix.Name{
+			Organization: []string{"Acme Co"},
+			CommonName:   commonName,
+		},
+		DNSNames:              []string{commonName},
+		NotBefore:             now.Add(-time.Hour),
+		NotAfter:              now.Add(expiry),
+		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
+		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+		BasicConstraintsValid: true,
+		IsCA:                  true,
+	}
+
+	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
+	require.NoError(t, err)
+
+	cert, err := x509.ParseCertificate(derBytes)
+	require.NoError(t, err)
+
+	roots := x509.NewCertPool()
+	roots.AddCert(cert)
+
+	keyBytes, err := x509.MarshalECPrivateKey(priv)
+	require.NoError(t, err)
+
+	// The key block is appended so the tests also cover VerifyCert skipping
+	// non-certificate PEM blocks.
+	out := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
+	out = append(out, pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyBytes})...)
+
+	return out, roots
+}
diff --git a/internal/config/validate.go b/internal/config/validate.go
--- a/internal/config/validate.go
+++ b/internal/config/validate.go
@@ -5,8 +5,9 @@ import (
 	"net/url"
 
 	"github.com/hashicorp/go-multierror"
+	"gitlab.com/gitlab-org/labkit/log"
 
-	"gitlab.com/gitlab-org/gitlab-pages/internal/config/tls"
+	pages_tls "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls"
 )
 
 func validateConfig(config *Config) error {
@@ -18,7 +19,13 @@ func validateConfig(config *Config) error {
 		return err
 	}
 
-	return tls.ValidateTLSVersions(*tlsMinVersion, *tlsMaxVersion)
+	// Pages are served from subdomains of the pages domain, so the root
+	// certificate is expected to carry a "*.<domain>" name, not the bare domain.
+	if err := pages_tls.VerifyCert("*."+config.General.Domain, config.General.RootCertificate); err != nil {
+		log.WithError(err).Warn("invalid root-cert, HTTPS connections may not work")
+	}
+
+	return pages_tls.ValidateTLSVersions(*tlsMinVersion, *tlsMaxVersion)
 }
 
 func validateAuthConfig(config *Config) error {
diff --git a/internal/domain/domain.go b/internal/domain/domain.go
--- a/internal/domain/domain.go
+++ b/internal/domain/domain.go
@@ -7,8 +7,10 @@ import (
 	"net/http"
 	"sync"
 
+	"github.com/hashicorp/go-multierror"
 	"gitlab.com/gitlab-org/labkit/errortracking"
 
+	pages_tls "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls"
 	"gitlab.com/gitlab-org/gitlab-pages/internal/httperrors"
 	"gitlab.com/gitlab-org/gitlab-pages/internal/serving"
 )
@@ -118,5 +120,12 @@ func (d *Domain) EnsureCertificate() (*tls.Certificate, error) {
 		}
+
+		// Verify once, together with loading the key pair: EnsureCertificate
+		// runs on every TLS handshake, so the chain is not re-parsed and
+		// re-verified each time. The error is returned alongside d.certificate.
+		if err := pages_tls.VerifyCert(d.Name, []byte(d.CertificateCert)); err != nil {
+			d.certificateError = multierror.Append(d.certificateError, err)
+		}
 	})
 
 	return d.certificate, d.certificateError
 }