package sequencer

import (
	"bufio"
	"bytes"
	"compress/zlib"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"math"

	l2types "github.com/ethereum-optimism/optimism/l2geth/core/types"
	l2rlp "github.com/ethereum-optimism/optimism/l2geth/rlp"
)

// TxLenSize is the number of bytes used to encode the length prefix of each
// serialized sequencer transaction.
const TxLenSize = 3

// byteOrder is the endianness used for every integer in the batch encoding.
var byteOrder = binary.BigEndian

// ErrMalformedBatch is returned for a batch that is not well formed according
// to the protocol specification, e.g. one that has transactions but no
// contexts, or contexts but no transactions.
var ErrMalformedBatch = errors.New("malformed batch")

// BatchContext denotes a range of transactions that belong to the same batch.
// It is used to compress shared fields that would otherwise be repeated for
// each transaction. Its wire encoding is a fixed 16 bytes (see Write and Read).
type BatchContext struct {
	// NumSequencedTxs specifies the number of sequencer txs included in
	// the batch. Encoded in 3 bytes.
	NumSequencedTxs uint64 `json:"num_sequenced_txs"`

	// NumSubsequentQueueTxs specifies the number of queued txs included in
	// the batch. Encoded in 3 bytes.
	NumSubsequentQueueTxs uint64 `json:"num_subsequent_queue_txs"`

	// Timestamp is the L1 timestamp of the batch. Encoded in 5 bytes. A value
	// of 0 in the first context of a batch marks a typed batch.
	Timestamp uint64 `json:"timestamp"`

	// BlockNumber is the L1 BlockNumber of the batch. Encoded in 5 bytes. In
	// the dummy first context of a typed batch it holds the BatchType instead.
	BlockNumber uint64 `json:"block_number"`
}

// Write appends the 16-byte encoding of c to w. Fields are written in this
// order and width:
//  - num_sequenced_txs:        3 bytes
//  - num_subsequent_queue_txs: 3 bytes
//  - timestamp:                5 bytes
//  - block_number:             5 bytes
//
// A bytes.Buffer never reports a write error, so the errors from the
// individual writes are discarded.
func (c *BatchContext) Write(w *bytes.Buffer) {
	layout := [...]struct {
		value uint64
		width uint
	}{
		{c.NumSequencedTxs, 3},
		{c.NumSubsequentQueueTxs, 3},
		{c.Timestamp, 5},
		{c.BlockNumber, 5},
	}
	for _, field := range layout {
		_ = writeUint64(w, field.value, field.width)
	}
}

// Read fills c from r, consuming the same 16-byte layout that Write produces:
//  - num_sequenced_txs:        3 bytes
//  - num_subsequent_queue_txs: 3 bytes
//  - timestamp:                5 bytes
//  - block_number:             5 bytes
//
// Fields are decoded in order. The first read error is returned unchanged,
// and the fields after it are not touched.
func (c *BatchContext) Read(r io.Reader) error {
	layout := [...]struct {
		dst   *uint64
		width uint
	}{
		{&c.NumSequencedTxs, 3},
		{&c.NumSubsequentQueueTxs, 3},
		{&c.Timestamp, 5},
		{&c.BlockNumber, 5},
	}
	for _, field := range layout {
		if err := readUint64(r, field.dst, field.width); err != nil {
			return err
		}
	}
	return nil
}

// BatchType represents the type of batch being
// submitted. When the first context in the batch
// has a timestamp of 0, the blocknumber is interpreted
// as an enum that represets the type
type BatchType int8

// Implements the Stringer interface for BatchType
func (b BatchType) String() string {
	switch b {
	case BatchTypeLegacy:
		return "LEGACY"
	case BatchTypeZlib:
		return "ZLIB"
	default:
		return ""
	}
}

// BatchTypeFromString maps a human readable name to its BatchType. Only the
// exact spellings "zlib" and "ZLIB" select BatchTypeZlib. Every other input,
// including "legacy"/"LEGACY" and unrecognized strings, yields BatchTypeLegacy.
func BatchTypeFromString(s string) BatchType {
	if s == "zlib" || s == "ZLIB" {
		return BatchTypeZlib
	}
	return BatchTypeLegacy
}

// Known batch types. BatchTypeZlib is 0, matching the blocknumber of the
// dummy context that marks a zlib batch. Legacy batches carry no dummy
// context, so BatchTypeLegacy's value is never written to the wire.
const (
	// BatchTypeLegacy represents the legacy batch type, whose transaction
	// data is written uncompressed.
	BatchTypeLegacy = BatchType(-1)
	// BatchTypeZlib represents a batch type where the transaction data is
	// compressed using zlib.
	BatchTypeZlib = BatchType(0)
)

// AppendSequencerBatchParams holds the raw data required to submit a batch of
// L2 txs to the L1 CTC contract. Rather than encoding the objects using the
// standard ABI encoding, a custom encoding is used and provided in the call
// data to optimize for gas fees, since batch submission of L2 txs is a primary
// cost driver.
type AppendSequencerBatchParams struct {
	// ShouldStartAtElement specifies the intended starting sequence number
	// of the provided transaction. Upon submission, this should match the
	// CTC's expected value otherwise the transaction will revert.
	ShouldStartAtElement uint64

	// TotalElementsToAppend indicates the number of L2 txs represented by
	// this batch. This includes both sequencer and queued txs.
	TotalElementsToAppend uint64

	// Contexts aggregates redundant L1 block numbers and L1 timestamps for
	// the txns encoded in the Tx slice. Further, they specify consecutive
	// tx windows in Txs and implicitly allow one to compute how many
	// (omitted) queued txs are in a given window. It never includes the
	// dummy context that marks a typed batch on the wire.
	Contexts []BatchContext

	// Txs contains all sequencer txs that will be recorded in the L1 CTC
	// contract.
	Txs []*CachedTx

	// Type selects how the tx data is encoded: BatchTypeLegacy writes it
	// as-is, BatchTypeZlib compresses it with zlib.
	Type BatchType
}

// Write encodes the AppendSequencerBatchParams using the following format:
//  - should_start_at_element:        5 bytes
//  - total_elements_to_append:       3 bytes
//  - num_contexts:                   3 bytes
//    - num_contexts * batch_context: num_contexts * 16 bytes
//  - [num txs omitted]
//    - tx_len:                       3 bytes
//    - tx_bytes:                     tx_len bytes
//
// Typed batches include a dummy context as the first context, where the
// timestamp is 0. The blocknumber is then interpreted as an enum that defines
// the type. It is impossible to have a timestamp of 0 in practice, so this
// safely indicates that the batch is typed.
//
// Type 0 (BatchTypeZlib) batches have a dummy context whose blocknumber is 0.
// The tx_len/tx_bytes section is compressed with zlib. The fields
// should_start_at_element, total_elements_to_append, num_contexts (which
// counts the dummy context) and the contexts themselves are not compressed.
// p.Contexts itself is never modified.
//
// The batch is validated before anything is written, so w is left untouched
// whenever an error is returned. Write returns ErrMalformedBatch if the batch
// has transactions but no contexts, or contexts but no transactions. It
// returns an error naming the value if p.Type is not a known batch type.
// Writing to a bytes.Buffer cannot fail, so those write errors are ignored.
func (p *AppendSequencerBatchParams) Write(w *bytes.Buffer) error {
	// There must be contexts if there are transactions, and vice versa.
	if len(p.Contexts) == 0 && len(p.Txs) != 0 {
		return ErrMalformedBatch
	}
	if len(p.Txs) == 0 && len(p.Contexts) != 0 {
		return ErrMalformedBatch
	}
	// Reject unknown types before emitting the header, so a batch whose body
	// cannot be encoded never leaves partial output behind. %d is used because
	// String() returns "" for unknown types.
	if p.Type != BatchTypeLegacy && p.Type != BatchTypeZlib {
		return fmt.Errorf("unknown batch type: %d", p.Type)
	}

	_ = writeUint64(w, p.ShouldStartAtElement, 5)
	_ = writeUint64(w, p.TotalElementsToAppend, 3)

	// Copy the contexts so that prepending the typed-batch dummy context does
	// not mutate p.Contexts.
	contexts := make([]BatchContext, 0, len(p.Contexts)+1)
	if p.Type == BatchTypeZlib {
		// An all-zero context (timestamp 0, blocknumber 0) marks a zlib batch.
		contexts = append(contexts, BatchContext{})
	}
	contexts = append(contexts, p.Contexts...)

	// Write number of contexts followed by each fixed-size BatchContext.
	_ = writeUint64(w, uint64(len(contexts)), 3)
	for i := range contexts {
		contexts[i].Write(w)
	}

	switch p.Type {
	case BatchTypeLegacy:
		// Write each length-prefixed tx directly.
		for _, tx := range p.Txs {
			_ = writeUint64(w, uint64(tx.Size()), TxLenSize)
			_, _ = w.Write(tx.RawTx())
		}
	case BatchTypeZlib:
		// Write each length-prefixed tx through the compressor.
		zw := zlib.NewWriter(w)
		for _, tx := range p.Txs {
			if err := writeUint64(zw, uint64(tx.Size()), TxLenSize); err != nil {
				return err
			}
			if _, err := zw.Write(tx.RawTx()); err != nil {
				return err
			}
		}
		// Close flushes any buffered compressed data to w.
		if err := zw.Close(); err != nil {
			return err
		}
	}

	return nil
}

// Serialize encodes p exactly like Write does and returns the encoded bytes.
// If Write fails, Serialize returns a nil slice along with Write's error.
func (p *AppendSequencerBatchParams) Serialize() ([]byte, error) {
	out := new(bytes.Buffer)
	err := p.Write(out)
	if err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}

// Read decodes the AppendSequencerBatchParams from a byte stream laid out as:
//  - should_start_at_element:        5 bytes
//  - total_elements_to_append:       3 bytes
//  - num_contexts:                   3 bytes
//    - num_contexts * batch_context: num_contexts * 16 bytes
//  - [num txs omitted]
//    - tx_len:                       3 bytes
//    - tx_bytes:                     tx_len bytes
//
// The number of txs is not encoded, so txs are read until the stream runs
// out. Decoding stops successfully when reading a tx_len returns io.EOF. Any
// other read or RLP decode error is returned.
//
// If the first context has a timestamp of 0 and a blocknumber of 0, the batch
// is zlib compressed. In that case the dummy context is dropped from
// p.Contexts, p.Type is set to BatchTypeZlib, and the tx section is
// decompressed before parsing. Otherwise p.Type is BatchTypeLegacy.
//
// ErrMalformedBatch is returned if the decoded batch has txs but no contexts,
// contexts but no txs, or a tx_len of 0. Any txs previously held in p.Txs are
// discarded.
func (p *AppendSequencerBatchParams) Read(r io.Reader) error {
	if err := readUint64(r, &p.ShouldStartAtElement, 5); err != nil {
		return err
	}
	if err := readUint64(r, &p.TotalElementsToAppend, 3); err != nil {
		return err
	}

	// Read number of contexts and deserialize each one.
	var numContexts uint64
	if err := readUint64(r, &numContexts, 3); err != nil {
		return err
	}

	// Ensure that Contexts is never nil. The slice is intentionally not
	// pre-sized with numContexts, which comes straight from the input.
	p.Contexts = make([]BatchContext, 0)
	for i := uint64(0); i < numContexts; i++ {
		var batchContext BatchContext
		if err := batchContext.Read(r); err != nil {
			return err
		}
		p.Contexts = append(p.Contexts, batchContext)
	}

	// Start from an empty tx list so a reused p does not mix in old txs.
	p.Txs = nil

	// Assume that it is a legacy batch at first.
	p.Type = BatchTypeLegacy

	// Handle backwards compatible batch types: a first context with
	// timestamp 0 is a marker whose blocknumber holds the type.
	if len(p.Contexts) > 0 && p.Contexts[0].Timestamp == 0 {
		switch p.Contexts[0].BlockNumber {
		case 0:
			// zlib compressed transaction data
			p.Type = BatchTypeZlib
			// remove the first dummy context
			p.Contexts = p.Contexts[1:]

			zr, err := zlib.NewReader(r)
			if err != nil {
				return err
			}
			defer zr.Close()

			r = zr
		}
		// NOTE(review): any other blocknumber is unrecognized; the marker
		// context is kept and the batch is decoded as legacy.
	}

	// rlp.NewStream adds its own bufio.Reader around any reader that is not
	// an io.ByteReader. That buffer can read past the tx being decoded and
	// swallow the following tx_len. Buffer once here so every per-tx stream
	// reads from the same reader.
	if _, ok := r.(io.ByteReader); !ok {
		r = bufio.NewReader(r)
	}

	// Deserialize any transactions. Since the number of txs is omitted
	// from the encoding, loop until the stream is consumed.
	for {
		var txLen uint64
		err := readUint64(r, &txLen, TxLenSize)
		// Getting an EOF when reading the txLen is expected for a cleanly
		// encoded object. Silence the error and return success if the batch
		// is well formed.
		if err == io.EOF {
			if len(p.Contexts) == 0 && len(p.Txs) != 0 {
				return ErrMalformedBatch
			}
			if len(p.Txs) == 0 && len(p.Contexts) != 0 {
				return ErrMalformedBatch
			}
			return nil
		}
		if err != nil {
			return err
		}

		// rlp.NewStream treats an input limit of 0 as "no limit", so a zero
		// tx_len would let the decoder consume bytes beyond this tx.
		if txLen == 0 {
			return ErrMalformedBatch
		}

		tx := new(l2types.Transaction)
		if err := tx.DecodeRLP(l2rlp.NewStream(r, txLen)); err != nil {
			return err
		}

		p.Txs = append(p.Txs, NewCachedTx(tx))
	}
}

// writeUint64 writes the low n bytes of val to w, ordered by byteOrder, and
// returns the error from w.Write. It panics if n is outside 1-8 or if val
// does not fit in n bytes.
func writeUint64(w io.Writer, val uint64, n uint) error {
	if !(n >= 1 && n <= 8) {
		panic(fmt.Sprintf("invalid number of bytes %d must be 1-8", n))
	}

	if limit := uint64(math.MaxUint64) >> (64 - 8*n); val > limit {
		panic(fmt.Sprintf("cannot encode %d in %d byte value", val, n))
	}

	var scratch [8]byte
	byteOrder.PutUint64(scratch[:], val)
	_, err := w.Write(scratch[len(scratch)-int(n):])
	return err
}

// readUint64 reads exactly n bytes from r, decodes them with byteOrder, and
// stores the result in the low n bytes of *val.
//
// It returns io.EOF if r is exhausted before any byte is read, and
// io.ErrUnexpectedEOF if r ends part-way through the value. *val is only
// written on success. It panics if n is outside 1-8, mirroring writeUint64.
func readUint64(r io.Reader, val *uint64, n uint) error {
	if n < 1 || n > 8 {
		panic(fmt.Sprintf("invalid number of bytes %d must be 1-8", n))
	}

	var buf [8]byte
	// A single Read may legally return fewer than n bytes with a nil error,
	// or all n bytes together with io.EOF. io.ReadFull handles both cases,
	// where one Read would decode a zero-padded value or report a spurious
	// EOF.
	if _, err := io.ReadFull(r, buf[8-n:]); err != nil {
		return err
	}
	*val = byteOrder.Uint64(buf[:])
	return nil
}