package database

import (
	"database/sql/driver"
	"errors"
	"io"
	"math/big"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/rlp"
	"github.com/jackc/pgtype"
)

var (
	// u256BigIntOverflow is 2^256, the smallest value that no longer fits in a u256.
	u256BigIntOverflow = new(big.Int).Lsh(big.NewInt(1), 256)

	// big10 is the base used to expand a numeric's decimal exponent.
	big10 = big.NewInt(10)
)

var (
	// ErrU256Overflow is returned when a value does not fit in a u256.
	ErrU256Overflow = errors.New("number exceeds u256")

	// ErrU256ContainsDecimal is returned when a numeric carries fractional digits.
	ErrU256ContainsDecimal = errors.New("number contains fractional digits")

	// ErrU256Null is returned when a NULL / nil value is read or written.
	ErrU256Null = errors.New("number cannot be null")
)

// U256 is a wrapper over big.Int that conforms to the database U256 numeric domain type.
//
// *U256 implements the database/sql Scanner interface and U256 implements the
// database/sql/driver Valuer interface. Both sides enforce in Go that the
// number is non-null, has no fractional digits, and is below 2^256.
type U256 struct {
	// Int holds the wrapped value. Value rejects a nil Int with ErrU256Null;
	// Scan only assigns Int once the scanned number has passed validation.
	Int *big.Int
}

// Scan implements the database/sql Scanner interface.
func (u256 *U256) Scan(src interface{}) error {
	// deserialize as a numeric
	var numeric pgtype.Numeric
	err := numeric.Scan(src)
	if err != nil {
		return err
	} else if numeric.Exp < 0 {
		return ErrU256ContainsDecimal
	} else if numeric.Status == pgtype.Null {
37
		return ErrU256Null
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
	}

	// factor in the powers of 10
	num := numeric.Int
	if numeric.Exp > 0 {
		factor := new(big.Int).Exp(big10, big.NewInt(int64(numeric.Exp)), nil)
		num.Mul(num, factor)
	}

	// check bounds before setting the u256
	if num.Cmp(u256BigIntOverflow) >= 0 {
		return ErrU256Overflow
	} else {
		u256.Int = num
	}

	return nil
}

// Value implements the database/sql/driver Valuer interface.
func (u256 U256) Value() (driver.Value, error) {
	// check bounds
	if u256.Int == nil {
61
		return nil, ErrU256Null
62 63 64 65 66 67 68 69
	} else if u256.Int.Cmp(u256BigIntOverflow) >= 0 {
		return nil, ErrU256Overflow
	}

	// simply encode as a numeric with no Exp set (non-decimal)
	numeric := pgtype.Numeric{Int: u256.Int, Status: pgtype.Present}
	return numeric.Value()
}

71
type RLPHeader types.Header
72

73
func (h *RLPHeader) EncodeRLP(w io.Writer) error {
74 75 76
	return types.NewBlockWithHeader((*types.Header)(h)).EncodeRLP(w)
}

77
func (h *RLPHeader) DecodeRLP(s *rlp.Stream) error {
78 79 80 81 82 83 84
	block := new(types.Block)
	err := block.DecodeRLP(s)
	if err != nil {
		return err
	}

	header := block.Header()
85
	*h = (RLPHeader)(*header)
86 87 88
	return nil
}

89
func (h *RLPHeader) Header() *types.Header {
90 91 92
	return (*types.Header)(h)
}

93
func (h *RLPHeader) Hash() common.Hash {
94 95
	return h.Header().Hash()
}