Commit 164063e9 authored by Petar Radovic's avatar Petar Radovic

handshake mock refactor

parent 72217321
......@@ -7,59 +7,16 @@ import (
"bytes"
"errors"
"fmt"
"github.com/ethersphere/bee/pkg/swarm"
"io/ioutil"
"testing"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake/mock"
"github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake/pb"
"github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/swarm"
)
type StreamMock struct {
readBuffer *bytes.Buffer
writeBuffer *bytes.Buffer
writeCounter int
readCounter int
readError error
writeError error
readErrCheckmark int
writeErrCheckmark int
}
func (s *StreamMock) setReadErr(err error, checkmark int) {
s.readError = err
s.readErrCheckmark = checkmark
}
func (s *StreamMock) setWriteErr(err error, checkmark int) {
s.writeError = err
s.writeErrCheckmark = checkmark
}
func (s *StreamMock) Read(p []byte) (n int, err error) {
if s.readError != nil && s.readErrCheckmark <= s.readCounter {
return 0, s.readError
}
s.readCounter++
return s.readBuffer.Read(p)
}
func (s *StreamMock) Write(p []byte) (n int, err error) {
if s.writeError != nil && s.writeErrCheckmark <= s.writeCounter {
return 0, s.writeError
}
s.writeCounter++
return s.writeBuffer.Write(p)
}
func (s *StreamMock) Close() error {
return nil
}
func TestHandshake(t *testing.T) {
node1Addr := swarm.MustParseHexAddress("ca1e9f3938cc1425c6061b96ad9eb93e134dfe8734ad490164ef20af9d1cf59c")
......@@ -81,8 +38,8 @@ func TestHandshake(t *testing.T) {
var buffer1 bytes.Buffer
var buffer2 bytes.Buffer
stream1 := &StreamMock{readBuffer: &buffer1, writeBuffer: &buffer2}
stream2 := &StreamMock{readBuffer: &buffer2, writeBuffer: &buffer1}
stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := mock.NewStream(&buffer2, &buffer1)
w, r := protobuf.NewWriterAndReader(stream2)
if err := w.WriteMsg(&pb.ShakeHandAck{
......@@ -111,8 +68,8 @@ func TestHandshake(t *testing.T) {
t.Run("ERROR - shakehand write error ", func(t *testing.T) {
testErr := errors.New("test error")
expectedErr := fmt.Errorf("write message: %w", testErr)
stream := &StreamMock{}
stream.setWriteErr(testErr, 0)
stream := &mock.StreamMock{}
stream.SetWriteErr(testErr, 0)
res, err := handshakeService.Handshake(stream)
if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err)
......@@ -126,8 +83,8 @@ func TestHandshake(t *testing.T) {
t.Run("ERROR - shakehand read error ", func(t *testing.T) {
testErr := errors.New("test error")
expectedErr := fmt.Errorf("read message: %w", testErr)
stream := &StreamMock{writeBuffer: &bytes.Buffer{}}
stream.setReadErr(testErr, 0)
stream := mock.NewStream(nil, &bytes.Buffer{})
stream.SetReadErr(testErr, 0)
res, err := handshakeService.Handshake(stream)
if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err)
......@@ -149,9 +106,9 @@ func TestHandshake(t *testing.T) {
var buffer1 bytes.Buffer
var buffer2 bytes.Buffer
stream1 := &StreamMock{readBuffer: &buffer1, writeBuffer: &buffer2}
stream1.setWriteErr(testErr, 1)
stream2 := &StreamMock{readBuffer: &buffer2, writeBuffer: &buffer1}
stream1 := mock.NewStream(&buffer1, &buffer2)
stream1.SetWriteErr(testErr, 1)
stream2 := mock.NewStream(&buffer2, &buffer1)
w, _ := protobuf.NewWriterAndReader(stream2)
if err := w.WriteMsg(&pb.ShakeHandAck{
......@@ -197,8 +154,8 @@ func TestHandle(t *testing.T) {
var buffer1 bytes.Buffer
var buffer2 bytes.Buffer
stream1 := &StreamMock{readBuffer: &buffer1, writeBuffer: &buffer2}
stream2 := &StreamMock{readBuffer: &buffer2, writeBuffer: &buffer1}
stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := mock.NewStream(&buffer2, &buffer1)
w, _ := protobuf.NewWriterAndReader(stream2)
if err := w.WriteMsg(&pb.ShakeHand{
......@@ -236,8 +193,8 @@ func TestHandle(t *testing.T) {
t.Run("ERROR - read error ", func(t *testing.T) {
testErr := errors.New("test error")
expectedErr := fmt.Errorf("read message: %w", testErr)
stream := &StreamMock{}
stream.setReadErr(testErr, 0)
stream := &mock.StreamMock{}
stream.SetReadErr(testErr, 0)
res, err := handshakeService.Handle(stream)
if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err)
......@@ -252,8 +209,8 @@ func TestHandle(t *testing.T) {
testErr := errors.New("test error")
expectedErr := fmt.Errorf("write message: %w", testErr)
var buffer bytes.Buffer
stream := &StreamMock{readBuffer: &buffer, writeBuffer: &buffer}
stream.setWriteErr(testErr, 1)
stream := mock.NewStream(&buffer, &buffer)
stream.SetWriteErr(testErr, 1)
w, _ := protobuf.NewWriterAndReader(stream)
if err := w.WriteMsg(&pb.ShakeHand{
Address: node1Addr.Bytes(),
......@@ -284,9 +241,9 @@ func TestHandle(t *testing.T) {
var buffer1 bytes.Buffer
var buffer2 bytes.Buffer
stream1 := &StreamMock{readBuffer: &buffer1, writeBuffer: &buffer2}
stream2 := &StreamMock{readBuffer: &buffer2, writeBuffer: &buffer1}
stream1.setReadErr(testErr, 1)
stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := mock.NewStream(&buffer2, &buffer1)
stream1.SetReadErr(testErr, 1)
w, _ := protobuf.NewWriterAndReader(stream2)
if err := w.WriteMsg(&pb.ShakeHand{
Address: node2Info.Address.Bytes(),
......
package mock
import "bytes"
type StreamMock struct {
readBuffer *bytes.Buffer
writeBuffer *bytes.Buffer
writeCounter int
readCounter int
readError error
writeError error
readErrCheckmark int
writeErrCheckmark int
}
func NewStream(readBuffer, writeBuffer *bytes.Buffer) *StreamMock {
return &StreamMock{readBuffer: readBuffer, writeBuffer: writeBuffer}
}
func (s *StreamMock) SetReadErr(err error, checkmark int) {
s.readError = err
s.readErrCheckmark = checkmark
}
func (s *StreamMock) SetWriteErr(err error, checkmark int) {
s.writeError = err
s.writeErrCheckmark = checkmark
}
func (s *StreamMock) Read(p []byte) (n int, err error) {
if s.readError != nil && s.readErrCheckmark <= s.readCounter {
return 0, s.readError
}
s.readCounter++
return s.readBuffer.Read(p)
}
func (s *StreamMock) Write(p []byte) (n int, err error) {
if s.writeError != nil && s.writeErrCheckmark <= s.writeCounter {
return 0, s.writeError
}
s.writeCounter++
return s.writeBuffer.Write(p)
}
func (s *StreamMock) Close() error {
return nil
}
......@@ -6,8 +6,9 @@ package mock
import (
"context"
"github.com/ethersphere/bee/pkg/swarm"
"time"
"github.com/ethersphere/bee/pkg/swarm"
)
type Service struct {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment