// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package api

import (
	"context"
	"crypto/ecdsa"
	"encoding/hex"
	"io/ioutil"
	"net/http"
	"strings"
	"time"

	"github.com/ethersphere/bee/pkg/crypto"
	"github.com/ethersphere/bee/pkg/jsonhttp"
	"github.com/ethersphere/bee/pkg/pss"
	"github.com/ethersphere/bee/pkg/swarm"
	"github.com/gorilla/mux"
	"github.com/gorilla/websocket"
)

// writeDeadline is the deadline set on the websocket connection before each
// write. It should be smaller than the shutdown timeout on api close.
var writeDeadline = 4 * time.Second

// targetMaxLength is the maximum length of a single target, in bytes. It is
// kept small in order to prevent griefing by excess computation.
var targetMaxLength = 2

func (s *server) pssPostHandler(w http.ResponseWriter, r *http.Request) {
30 31
	topicVar := mux.Vars(r)["topic"]
	topic := pss.NewTopic(topicVar)
32

33 34 35
	targetsVar := mux.Vars(r)["targets"]
	var targets pss.Targets
	tgts := strings.Split(targetsVar, ",")
36 37 38 39

	for _, v := range tgts {
		target, err := hex.DecodeString(v)
		if err != nil || len(target) > targetMaxLength {
40 41
			s.logger.Debugf("pss send: bad targets: %v", err)
			s.logger.Error("pss send: bad targets")
42 43 44 45 46 47
			jsonhttp.BadRequest(w, nil)
			return
		}
		targets = append(targets, target)
	}

48 49 50 51 52 53 54 55 56 57
	recipientQueryString := r.URL.Query().Get("recipient")
	var recipient *ecdsa.PublicKey
	if recipientQueryString == "" {
		// use topic-based encryption
		privkey := crypto.Secp256k1PrivateKeyFromBytes(topic[:])
		recipient = &privkey.PublicKey
	} else {
		var err error
		recipient, err = pss.ParseRecipient(recipientQueryString)
		if err != nil {
58 59
			s.logger.Debugf("pss recipient: %v", err)
			s.logger.Error("pss recipient")
60 61 62 63 64
			jsonhttp.BadRequest(w, nil)
			return
		}
	}

65 66
	payload, err := ioutil.ReadAll(r.Body)
	if err != nil {
67 68
		s.logger.Debugf("pss read payload: %v", err)
		s.logger.Error("pss read payload")
69 70 71 72
		jsonhttp.InternalServerError(w, nil)
		return
	}

73
	err = s.pss.Send(r.Context(), topic, payload, recipient, targets)
74
	if err != nil {
75 76
		s.logger.Debugf("pss send payload: %v. topic: %s", err, topicVar)
		s.logger.Error("pss send payload")
77 78 79 80 81 82 83 84
		jsonhttp.InternalServerError(w, nil)
		return
	}

	jsonhttp.OK(w, nil)
}

func (s *server) pssWsHandler(w http.ResponseWriter, r *http.Request) {
acud's avatar
acud committed
85 86 87 88 89 90 91

	upgrader := websocket.Upgrader{
		ReadBufferSize:  swarm.ChunkSize,
		WriteBufferSize: swarm.ChunkSize,
		CheckOrigin:     s.checkOrigin,
	}

92 93
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
94 95
		s.logger.Debugf("pss ws: upgrade: %v", err)
		s.logger.Error("pss ws: cannot upgrade")
96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
		jsonhttp.InternalServerError(w, nil)
		return
	}

	t := mux.Vars(r)["topic"]
	s.wsWg.Add(1)
	go s.pumpWs(conn, t)
}

func (s *server) pumpWs(conn *websocket.Conn, t string) {
	defer s.wsWg.Done()

	var (
		dataC  = make(chan []byte)
		gone   = make(chan struct{})
111
		topic  = pss.NewTopic(t)
112 113 114 115 116 117 118
		ticker = time.NewTicker(s.WsPingPeriod)
		err    error
	)
	defer func() {
		ticker.Stop()
		_ = conn.Close()
	}()
119
	cleanup := s.pss.Register(topic, func(_ context.Context, m []byte) {
120
		dataC <- m
121 122 123 124 125
	})

	defer cleanup()

	conn.SetCloseHandler(func(code int, text string) error {
126
		s.logger.Debugf("pss handler: client gone. code %d message %s", code, text)
127 128 129 130 131 132 133 134 135
		close(gone)
		return nil
	})

	for {
		select {
		case b := <-dataC:
			err = conn.SetWriteDeadline(time.Now().Add(writeDeadline))
			if err != nil {
136
				s.logger.Debugf("pss set write deadline: %v", err)
137 138 139 140 141
				return
			}

			err = conn.WriteMessage(websocket.BinaryMessage, b)
			if err != nil {
142
				s.logger.Debugf("pss write to websocket: %v", err)
143 144 145 146 147 148 149
				return
			}

		case <-s.quit:
			// shutdown
			err = conn.SetWriteDeadline(time.Now().Add(writeDeadline))
			if err != nil {
150
				s.logger.Debugf("pss set write deadline: %v", err)
151 152 153 154
				return
			}
			err = conn.WriteMessage(websocket.CloseMessage, []byte{})
			if err != nil {
155
				s.logger.Debugf("pss write close message: %v", err)
156 157 158 159 160 161 162 163
			}
			return
		case <-gone:
			// client gone
			return
		case <-ticker.C:
			err = conn.SetWriteDeadline(time.Now().Add(writeDeadline))
			if err != nil {
164
				s.logger.Debugf("pss set write deadline: %v", err)
165 166 167 168 169 170 171 172 173
				return
			}
			if err = conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				// error encountered while pinging client. client probably gone
				return
			}
		}
	}
}