Commit 164063e9 authored by Petar Radovic's avatar Petar Radovic

handshake mock refactor

parent 72217321
...@@ -7,59 +7,16 @@ import ( ...@@ -7,59 +7,16 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"github.com/ethersphere/bee/pkg/swarm"
"io/ioutil" "io/ioutil"
"testing" "testing"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake/mock"
"github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake/pb" "github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake/pb"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/swarm"
) )
type StreamMock struct {
readBuffer *bytes.Buffer
writeBuffer *bytes.Buffer
writeCounter int
readCounter int
readError error
writeError error
readErrCheckmark int
writeErrCheckmark int
}
func (s *StreamMock) setReadErr(err error, checkmark int) {
s.readError = err
s.readErrCheckmark = checkmark
}
func (s *StreamMock) setWriteErr(err error, checkmark int) {
s.writeError = err
s.writeErrCheckmark = checkmark
}
func (s *StreamMock) Read(p []byte) (n int, err error) {
if s.readError != nil && s.readErrCheckmark <= s.readCounter {
return 0, s.readError
}
s.readCounter++
return s.readBuffer.Read(p)
}
func (s *StreamMock) Write(p []byte) (n int, err error) {
if s.writeError != nil && s.writeErrCheckmark <= s.writeCounter {
return 0, s.writeError
}
s.writeCounter++
return s.writeBuffer.Write(p)
}
func (s *StreamMock) Close() error {
return nil
}
func TestHandshake(t *testing.T) { func TestHandshake(t *testing.T) {
node1Addr := swarm.MustParseHexAddress("ca1e9f3938cc1425c6061b96ad9eb93e134dfe8734ad490164ef20af9d1cf59c") node1Addr := swarm.MustParseHexAddress("ca1e9f3938cc1425c6061b96ad9eb93e134dfe8734ad490164ef20af9d1cf59c")
...@@ -81,8 +38,8 @@ func TestHandshake(t *testing.T) { ...@@ -81,8 +38,8 @@ func TestHandshake(t *testing.T) {
var buffer1 bytes.Buffer var buffer1 bytes.Buffer
var buffer2 bytes.Buffer var buffer2 bytes.Buffer
stream1 := &StreamMock{readBuffer: &buffer1, writeBuffer: &buffer2} stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := &StreamMock{readBuffer: &buffer2, writeBuffer: &buffer1} stream2 := mock.NewStream(&buffer2, &buffer1)
w, r := protobuf.NewWriterAndReader(stream2) w, r := protobuf.NewWriterAndReader(stream2)
if err := w.WriteMsg(&pb.ShakeHandAck{ if err := w.WriteMsg(&pb.ShakeHandAck{
...@@ -111,8 +68,8 @@ func TestHandshake(t *testing.T) { ...@@ -111,8 +68,8 @@ func TestHandshake(t *testing.T) {
t.Run("ERROR - shakehand write error ", func(t *testing.T) { t.Run("ERROR - shakehand write error ", func(t *testing.T) {
testErr := errors.New("test error") testErr := errors.New("test error")
expectedErr := fmt.Errorf("write message: %w", testErr) expectedErr := fmt.Errorf("write message: %w", testErr)
stream := &StreamMock{} stream := &mock.StreamMock{}
stream.setWriteErr(testErr, 0) stream.SetWriteErr(testErr, 0)
res, err := handshakeService.Handshake(stream) res, err := handshakeService.Handshake(stream)
if err == nil || err.Error() != expectedErr.Error() { if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err) t.Fatal("expected:", expectedErr, "got:", err)
...@@ -126,8 +83,8 @@ func TestHandshake(t *testing.T) { ...@@ -126,8 +83,8 @@ func TestHandshake(t *testing.T) {
t.Run("ERROR - shakehand read error ", func(t *testing.T) { t.Run("ERROR - shakehand read error ", func(t *testing.T) {
testErr := errors.New("test error") testErr := errors.New("test error")
expectedErr := fmt.Errorf("read message: %w", testErr) expectedErr := fmt.Errorf("read message: %w", testErr)
stream := &StreamMock{writeBuffer: &bytes.Buffer{}} stream := mock.NewStream(nil, &bytes.Buffer{})
stream.setReadErr(testErr, 0) stream.SetReadErr(testErr, 0)
res, err := handshakeService.Handshake(stream) res, err := handshakeService.Handshake(stream)
if err == nil || err.Error() != expectedErr.Error() { if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err) t.Fatal("expected:", expectedErr, "got:", err)
...@@ -149,9 +106,9 @@ func TestHandshake(t *testing.T) { ...@@ -149,9 +106,9 @@ func TestHandshake(t *testing.T) {
var buffer1 bytes.Buffer var buffer1 bytes.Buffer
var buffer2 bytes.Buffer var buffer2 bytes.Buffer
stream1 := &StreamMock{readBuffer: &buffer1, writeBuffer: &buffer2} stream1 := mock.NewStream(&buffer1, &buffer2)
stream1.setWriteErr(testErr, 1) stream1.SetWriteErr(testErr, 1)
stream2 := &StreamMock{readBuffer: &buffer2, writeBuffer: &buffer1} stream2 := mock.NewStream(&buffer2, &buffer1)
w, _ := protobuf.NewWriterAndReader(stream2) w, _ := protobuf.NewWriterAndReader(stream2)
if err := w.WriteMsg(&pb.ShakeHandAck{ if err := w.WriteMsg(&pb.ShakeHandAck{
...@@ -197,8 +154,8 @@ func TestHandle(t *testing.T) { ...@@ -197,8 +154,8 @@ func TestHandle(t *testing.T) {
var buffer1 bytes.Buffer var buffer1 bytes.Buffer
var buffer2 bytes.Buffer var buffer2 bytes.Buffer
stream1 := &StreamMock{readBuffer: &buffer1, writeBuffer: &buffer2} stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := &StreamMock{readBuffer: &buffer2, writeBuffer: &buffer1} stream2 := mock.NewStream(&buffer2, &buffer1)
w, _ := protobuf.NewWriterAndReader(stream2) w, _ := protobuf.NewWriterAndReader(stream2)
if err := w.WriteMsg(&pb.ShakeHand{ if err := w.WriteMsg(&pb.ShakeHand{
...@@ -236,8 +193,8 @@ func TestHandle(t *testing.T) { ...@@ -236,8 +193,8 @@ func TestHandle(t *testing.T) {
t.Run("ERROR - read error ", func(t *testing.T) { t.Run("ERROR - read error ", func(t *testing.T) {
testErr := errors.New("test error") testErr := errors.New("test error")
expectedErr := fmt.Errorf("read message: %w", testErr) expectedErr := fmt.Errorf("read message: %w", testErr)
stream := &StreamMock{} stream := &mock.StreamMock{}
stream.setReadErr(testErr, 0) stream.SetReadErr(testErr, 0)
res, err := handshakeService.Handle(stream) res, err := handshakeService.Handle(stream)
if err == nil || err.Error() != expectedErr.Error() { if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err) t.Fatal("expected:", expectedErr, "got:", err)
...@@ -252,8 +209,8 @@ func TestHandle(t *testing.T) { ...@@ -252,8 +209,8 @@ func TestHandle(t *testing.T) {
testErr := errors.New("test error") testErr := errors.New("test error")
expectedErr := fmt.Errorf("write message: %w", testErr) expectedErr := fmt.Errorf("write message: %w", testErr)
var buffer bytes.Buffer var buffer bytes.Buffer
stream := &StreamMock{readBuffer: &buffer, writeBuffer: &buffer} stream := mock.NewStream(&buffer, &buffer)
stream.setWriteErr(testErr, 1) stream.SetWriteErr(testErr, 1)
w, _ := protobuf.NewWriterAndReader(stream) w, _ := protobuf.NewWriterAndReader(stream)
if err := w.WriteMsg(&pb.ShakeHand{ if err := w.WriteMsg(&pb.ShakeHand{
Address: node1Addr.Bytes(), Address: node1Addr.Bytes(),
...@@ -284,9 +241,9 @@ func TestHandle(t *testing.T) { ...@@ -284,9 +241,9 @@ func TestHandle(t *testing.T) {
var buffer1 bytes.Buffer var buffer1 bytes.Buffer
var buffer2 bytes.Buffer var buffer2 bytes.Buffer
stream1 := &StreamMock{readBuffer: &buffer1, writeBuffer: &buffer2} stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := &StreamMock{readBuffer: &buffer2, writeBuffer: &buffer1} stream2 := mock.NewStream(&buffer2, &buffer1)
stream1.setReadErr(testErr, 1) stream1.SetReadErr(testErr, 1)
w, _ := protobuf.NewWriterAndReader(stream2) w, _ := protobuf.NewWriterAndReader(stream2)
if err := w.WriteMsg(&pb.ShakeHand{ if err := w.WriteMsg(&pb.ShakeHand{
Address: node2Info.Address.Bytes(), Address: node2Info.Address.Bytes(),
......
package mock
import "bytes"
type StreamMock struct {
readBuffer *bytes.Buffer
writeBuffer *bytes.Buffer
writeCounter int
readCounter int
readError error
writeError error
readErrCheckmark int
writeErrCheckmark int
}
func NewStream(readBuffer, writeBuffer *bytes.Buffer) *StreamMock {
return &StreamMock{readBuffer: readBuffer, writeBuffer: writeBuffer}
}
func (s *StreamMock) SetReadErr(err error, checkmark int) {
s.readError = err
s.readErrCheckmark = checkmark
}
func (s *StreamMock) SetWriteErr(err error, checkmark int) {
s.writeError = err
s.writeErrCheckmark = checkmark
}
func (s *StreamMock) Read(p []byte) (n int, err error) {
if s.readError != nil && s.readErrCheckmark <= s.readCounter {
return 0, s.readError
}
s.readCounter++
return s.readBuffer.Read(p)
}
func (s *StreamMock) Write(p []byte) (n int, err error) {
if s.writeError != nil && s.writeErrCheckmark <= s.writeCounter {
return 0, s.writeError
}
s.writeCounter++
return s.writeBuffer.Write(p)
}
func (s *StreamMock) Close() error {
return nil
}
...@@ -6,8 +6,9 @@ package mock ...@@ -6,8 +6,9 @@ package mock
import ( import (
"context" "context"
"github.com/ethersphere/bee/pkg/swarm"
"time" "time"
"github.com/ethersphere/bee/pkg/swarm"
) )
type Service struct { type Service struct {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment