certman.go 4.23 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
// Package certman provides live reloading of the certificate and key
// files used by the standard library http.Server. It defines a type,
// CertMan, with methods for watching and getting the files.
// Only valid certificate and key pairs are loaded. A logger implementing
// the log.Logger interface is passed to New and is used to report
// load results and watch events.
package certman

import (
	"crypto/tls"
	"fmt"
	"path"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/ethereum/go-ethereum/log"
	"github.com/fsnotify/fsnotify"
)

// A CertMan represents a certificate manager able to watch certificate
// and key pairs for changes.
//
// A CertMan is created with New, started with Watch and stopped with
// Stop. Its GetCertificate and GetClientCertificate methods have the
// signatures expected by the matching tls.Config callback fields.
type CertMan struct {
	// mu guards keyPair: load takes the write lock to swap it, the
	// Get* methods take the read lock to return it.
	mu       sync.RWMutex
	// certFile is the absolute path of the certificate file (made
	// absolute by New).
	certFile string
	// keyFile is the absolute path of the key file (made absolute by New).
	keyFile  string
	// keyPair is the most recently successfully loaded pair; it is nil
	// until the first successful load.
	keyPair  *tls.Certificate
	// watcher is the fsnotify watcher created by Watch and closed when
	// the run loop exits.
	watcher  *fsnotify.Watcher
	// watching is the unbuffered channel created by Watch; a value
	// sent on it (by Stop) ends the run loop.
	watching chan bool
	// log receives all informational and error messages.
	log      log.Logger
}

// New returns a CertMan for the given certificate and key files.
// certFile and keyFile may be relative or absolute; both are
// converted to absolute paths before being stored. The files are
// not read here: loading happens once Watch is called. logger is
// stored and used for all messages emitted by the CertMan.
//
// An error is returned only if a path cannot be made absolute.
func New(logger log.Logger, certFile, keyFile string) (*CertMan, error) {
	absCert, err := filepath.Abs(certFile)
	if err != nil {
		return nil, err
	}

	absKey, err := filepath.Abs(keyFile)
	if err != nil {
		return nil, err
	}

	// The zero value of the embedded RWMutex is ready to use.
	return &CertMan{
		certFile: absCert,
		keyFile:  absKey,
		log:      logger,
	}, nil
}

// Watch starts watching for changes to the certificate
// and key files. On any change the certificate and key
// are reloaded. If there is an issue the load will fail
// and the old (if any) certificates and keys will continue
// to be used.
//
// The directories containing the files are watched rather than the
// files themselves. An initial load is attempted before the watch
// goroutine starts; a failure of that initial load is logged but is
// not returned as an error. Watch returns an error only when the
// file-system watcher cannot be created or a directory cannot be
// added to it, in which case no goroutine is started and the watcher
// is released.
func (cm *CertMan) Watch() error {
	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		return fmt.Errorf("certman: can't create watcher: %w", err)
	}

	// NOTE(review): path.Dir only understands slash-separated paths;
	// filepath.Dir would be required for Windows-style paths - confirm
	// the supported platforms before changing.
	certPath := path.Dir(cm.certFile)
	keyPath := path.Dir(cm.keyFile)

	if err := watcher.Add(certPath); err != nil {
		// Best-effort cleanup; the Add error is the one worth reporting.
		_ = watcher.Close()
		return fmt.Errorf("certman: can't watch %s: %w", certPath, err)
	}
	if keyPath != certPath {
		if err := watcher.Add(keyPath); err != nil {
			// Best-effort cleanup; the Add error is the one worth reporting.
			_ = watcher.Close()
			return fmt.Errorf("certman: can't watch %s: %w", keyPath, err)
		}
	}

	// Publish the watcher only once it is fully set up.
	cm.watcher = watcher

	// A failed initial load is not fatal: the files may appear or be
	// fixed later, and the watch loop will pick them up.
	if err := cm.load(); err != nil {
		cm.log.Error("certman: can't load cert or key file", "err", err)
	}
	cm.log.Info("certman: watching for cert and key change")
	cm.watching = make(chan bool)
	go cm.run()
	return nil
}

// load reads and parses the certificate and key files and, on
// success, replaces the currently held key pair under the write
// lock. On failure the previously loaded pair (if any) is left in
// place and the parse error is returned.
func (cm *CertMan) load() error {
	pair, err := tls.LoadX509KeyPair(cm.certFile, cm.keyFile)
	if err != nil {
		return err
	}

	cm.mu.Lock()
	cm.keyPair = &pair
	cm.mu.Unlock()

	cm.log.Info("certman: certificate and key loaded")
	return nil
}

// run is the watch loop started by Watch. It reacts to four inputs:
//
//   - a value on cm.watching ends the loop;
//   - a file-system event naming the certificate or key file, or
//     ending in "/..data", schedules a reload two seconds later;
//   - a ticker tick (every two seconds) performs the scheduled
//     reload once its time has passed;
//   - a watcher error is logged.
//
// When the loop ends the watcher is closed and the ticker stopped.
func (cm *CertMan) run() {
	cm.log.Info("certman: running")

	tick := time.NewTicker(2 * time.Second)
	watched := []string{cm.certFile, cm.keyFile}
	// reloadAt is the earliest time a pending reload may run; the
	// zero value means no reload is pending.
	var reloadAt time.Time

	for stopped := false; !stopped; {
		select {
		case <-cm.watching:
			cm.log.Info("watching triggered; break loop")
			stopped = true

		case <-tick.C:
			if reloadAt.IsZero() || !time.Now().After(reloadAt) {
				continue
			}
			reloadAt = time.Time{}
			cm.log.Info("certman: reloading")
			if err := cm.load(); err != nil {
				cm.log.Error("certman: can't load cert or key file", "err", err)
			}

		case ev := <-cm.watcher.Events:
			for _, name := range watched {
				// "/..data" is the symlink swapped by a kubernetes secrets mount.
				if ev.Name != name && !strings.HasSuffix(ev.Name, "/..data") {
					continue
				}
				// Delay the reload a couple of seconds in case the cert
				// and key are not updated atomically.
				cm.log.Info(fmt.Sprintf("%s was modified, queue reload", name))
				reloadAt = time.Now().Add(2 * time.Second)
			}

		case watchErr := <-cm.watcher.Errors:
			cm.log.Error("certman: error watching files", "err", watchErr)
		}
	}

	cm.log.Info("certman: stopped watching")
	cm.watcher.Close()
	tick.Stop()
}

// GetCertificate returns the currently loaded certificate. Its
// signature matches the GetCertificate field of tls.Config. The
// hello argument is not inspected. The returned pointer is nil if
// no certificate has been loaded successfully yet; the error is
// always nil.
func (cm *CertMan) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
	cm.mu.RLock()
	current := cm.keyPair
	cm.mu.RUnlock()

	return current, nil
}

// GetClientCertificate returns the currently loaded certificate. Its
// signature matches the GetClientCertificate field of tls.Config.
// The hello argument is not inspected and the error is always nil.
//
// If no certificate has been loaded successfully yet, an empty
// tls.Certificate is returned instead of nil: the tls.Config
// contract requires a non-nil certificate when the error is nil,
// and an empty one means that no client certificate is sent.
func (cm *CertMan) GetClientCertificate(hello *tls.CertificateRequestInfo) (*tls.Certificate, error) {
	cm.mu.RLock()
	defer cm.mu.RUnlock()

	if cm.keyPair == nil {
		// Never hand crypto/tls a nil certificate; see the doc comment.
		return &tls.Certificate{}, nil
	}
	return cm.keyPair, nil
}

// Stop tells the CertMan to stop watching for changes to the
// certificate and key files.
//
// It is a no-op if Watch has not been called. Otherwise it blocks
// until the watch goroutine receives the signal, because the channel
// is unbuffered. Stop must be called at most once per successful
// Watch: a second call has no receiver and would block forever.
func (cm *CertMan) Stop() {
	// A send on a nil channel blocks forever, so bail out when no
	// watch was ever started.
	if cm.watching == nil {
		return
	}
	cm.watching <- false
}