Commit f2b4bdc5 authored by Nemanja Zbiljić's avatar Nemanja Zbiljić Committed by GitHub

Respect 'Writer' interface contract in 'ChunkPipe' (#560)

parent d5723357
......@@ -40,20 +40,35 @@ func (c *ChunkPipe) Read(b []byte) (int, error) {
// Writer implements io.Writer
func (c *ChunkPipe) Write(b []byte) (int, error) {
copy(c.data[c.cursor:], b)
c.cursor += len(b)
if c.cursor >= swarm.ChunkSize {
_, err := c.writer.Write(c.data[:swarm.ChunkSize])
if err != nil {
return len(b), err
nw := 0
for {
if nw >= len(b) {
break
}
copied := copy(c.data[c.cursor:], b[nw:])
c.cursor += copied
nw += copied
if c.cursor >= swarm.ChunkSize {
// NOTE: the Write method contract requires all sent data to be
// written before returning (without error)
_, err := c.writer.Write(c.data[:swarm.ChunkSize])
if err != nil {
return nw, err
}
c.cursor -= swarm.ChunkSize
copy(c.data, c.data[swarm.ChunkSize:])
}
c.cursor -= swarm.ChunkSize
copy(c.data, c.data[swarm.ChunkSize:])
}
return len(b), nil
return nw, nil
}
// Closer implements io.Closer
// Close implements io.Closer
func (c *ChunkPipe) Close() error {
if c.cursor > 0 {
_, err := c.writer.Write(c.data[:c.cursor])
......
......@@ -5,8 +5,11 @@
package file_test
import (
"bytes"
"errors"
"fmt"
"io"
"math/rand"
"strconv"
"strings"
"testing"
......@@ -113,3 +116,134 @@ func testChunkPipe(t *testing.T) {
}
}
}
func TestCopyBuffer(t *testing.T) {
readBufferSizes := []int{
64,
1024,
swarm.ChunkSize,
}
dataSizes := []int{
1,
64,
1024,
swarm.ChunkSize - 1,
swarm.ChunkSize,
swarm.ChunkSize + 1,
swarm.ChunkSize * 2,
swarm.ChunkSize*2 + 3,
swarm.ChunkSize * 5,
swarm.ChunkSize*5 + 3,
swarm.ChunkSize * 17,
swarm.ChunkSize*17 + 3,
}
testCases := []struct {
readBufferSize int
dataSize int
}{}
for i := 0; i < len(readBufferSizes); i++ {
for j := 0; j < len(dataSizes); j++ {
testCases = append(testCases, struct {
readBufferSize int
dataSize int
}{readBufferSizes[i], dataSizes[j]})
}
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("buf_%-4d/data_size_%d", tc.readBufferSize, tc.dataSize), func(t *testing.T) {
// https://golang.org/doc/faq#closures_and_goroutines
readBufferSize := tc.readBufferSize
dataSize := tc.dataSize
srcBytes := make([]byte, dataSize)
rand.Read(srcBytes)
chunkPipe := file.NewChunkPipe()
// destination
sizeC := make(chan int)
dataC := make(chan []byte)
go reader(t, readBufferSize, chunkPipe, sizeC, dataC)
// source
errC := make(chan error, 1)
go func() {
src := bytes.NewReader(srcBytes)
buf := make([]byte, swarm.ChunkSize)
c, err := io.CopyBuffer(chunkPipe, src, buf)
if err != nil {
errC <- err
}
if c != int64(dataSize) {
errC <- errors.New("read count mismatch")
}
err = chunkPipe.Close()
if err != nil {
errC <- err
}
close(errC)
}()
// receive the writes
// err may or may not be EOF, depending on whether writes end on
// chunk boundary
expected := dataSize
timer := time.NewTimer(time.Second)
readTotal := 0
readData := []byte{}
for {
select {
case c := <-sizeC:
readTotal += c
if readTotal == expected {
// check received content
if !bytes.Equal(srcBytes, readData) {
t.Fatal("invalid byte content received")
}
return
}
case d := <-dataC:
readData = append(readData, d...)
case err := <-errC:
if err != nil {
if err != io.EOF {
t.Fatal(err)
}
}
case <-timer.C:
t.Fatal("timeout")
}
}
})
}
}
func reader(t *testing.T, bufferSize int, r io.Reader, c chan int, cd chan []byte) {
var buf = make([]byte, bufferSize)
for {
n, err := r.Read(buf)
if err == io.EOF {
c <- 0
break
}
if err != nil {
t.Errorf("read: %v", err)
}
b := make([]byte, n)
copy(b, buf)
cd <- b
c <- n
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment