package orderbook

import (
	"container/list"
	"encoding/json"

	"github.com/ethereum/go-ethereum/common"
	"github.com/holiman/uint256"
)

// OrderBook implements a standard order matching algorithm.
//
// Resting sell orders are kept in asks and resting buy orders in bids; the
// orders map indexes every resting order by its ID so it can be looked up
// (Order) or removed (CancelOrder) directly. BaseToken and QuoteToken are set
// from NewOrderBook's arguments and are not part of the MarshalJSON output.
type OrderBook struct {
	// orders maps orderID -> the *list.Element that holds the order inside its
	// price-level queue; the element's Value is the *Order itself.
	orders map[string]*list.Element

	asks       *OrderSide // resting sell orders
	bids       *OrderSide // resting buy orders
	BaseToken  string     // set from NewOrderBook's baseToken argument
	QuoteToken string     // set from NewOrderBook's quoteToken argument
}

// NewOrderBook returns an empty order book for the baseToken/quoteToken pair,
// with freshly created bid and ask sides and an empty order index.
func NewOrderBook(baseToken, quoteToken string) *OrderBook {
	book := new(OrderBook)
	book.orders = make(map[string]*list.Element)
	book.bids = NewOrderSide()
	book.asks = NewOrderSide()
	book.BaseToken = baseToken
	book.QuoteToken = quoteToken
	return book
}

// PriceLevel describes one aggregated level of order-book depth, as produced
// by Depth.
type PriceLevel struct {
	Price    *uint256.Int `json:"price"`    // price of the level
	Quantity *uint256.Int `json:"quantity"` // the level's volume, as reported by OrderQueue.Volume
}

// CalculatePriceAfterExecution estimates the price at which a market order of
// the given quantity would finish executing. It walks the opposite side of the
// book from its best price (lowest ask for a Buy, highest bid otherwise)
// without removing or filling any orders.
//
// Parameters:
//
//	side     - trade direction (Buy or Sell)
//	orderId  - currently unused; kept for API compatibility
//	quantity - amount that would be bought or sold
//	* create uint256 values with uint256.NewInt(); see https://github.com/holiman/uint256
//
// Returns:
//
//	price - price of the last level the order would touch, i.e. the level that
//	        absorbs the final unit, or the worst level on that side when the book
//	        cannot supply the whole quantity (no error is reported in that case).
//	        Zero when the opposite side is empty or quantity is zero.
//	err   - ErrOverflow (with a nil price) if subtracting a level's volume from
//	        the remaining quantity overflows; this is a defensive check.
func (ob *OrderBook) CalculatePriceAfterExecution(side Side, orderId string, quantity *uint256.Int) (price *uint256.Int, err error) {
	price = uint256.NewInt(0)

	var (
		level *OrderQueue
		next  func(*uint256.Int) *OrderQueue
	)
	if side == Buy {
		level, next = ob.asks.MinPriceQueue(), ob.asks.GreaterThan
	} else {
		level, next = ob.bids.MaxPriceQueue(), ob.bids.LessThan
	}

	zero := uint256.NewInt(0)
	for quantity.Cmp(zero) > 0 && level != nil {
		levelVolume := level.Volume()
		levelPrice := level.Price()
		price = levelPrice

		if quantity.Cmp(levelVolume) < 0 {
			// This level absorbs everything that is left.
			break
		}

		var overflow bool
		if quantity, overflow = new(uint256.Int).SubOverflow(quantity, levelVolume); overflow {
			return nil, ErrOverflow
		}
		level = next(levelPrice)
	}

	return price, nil
}

// ProcessMarketOrder immediately takes quantity from the opposite side of the
// book, consuming price levels from the best price outward.
//
// Parameters:
//
//	side     - trade direction (Buy or Sell)
//	quantity - amount to buy or sell; must be > 0
//	* create uint256 values with uint256.NewInt(); see https://github.com/holiman/uint256
//
// Returns:
//
//	done                     - counter orders that this market order filled completely
//	partial                  - the counter order left partially filled, or nil
//	partialQuantityProcessed - quantity taken from partial (zero when partial is nil)
//	quantityLeft             - quantity that could not be matched because the
//	                           opposite side ran out (zero when fully matched)
//	err                      - ErrInvalidQuantity if quantity <= 0, ErrOverflow,
//	                           or an error from removing a filled counter order
func (ob *OrderBook) ProcessMarketOrder(side Side, quantity *uint256.Int) (done []*Order, partial *Order, partialQuantityProcessed, quantityLeft *uint256.Int, err error) {
	zero := uint256.NewInt(0)
	if quantity.Cmp(zero) <= 0 {
		return nil, nil, uint256.NewInt(0), uint256.NewInt(0), ErrInvalidQuantity
	}

	var (
		bestQueue     func() *OrderQueue
		sideToProcess *OrderSide
	)
	if side == Buy {
		bestQueue = ob.asks.MinPriceQueue
		sideToProcess = ob.asks
	} else {
		bestQueue = ob.bids.MaxPriceQueue
		sideToProcess = ob.bids
	}

	// Start from zero so callers never receive a nil partialQuantityProcessed,
	// even when the opposite side is empty and no level is processed.
	partialQuantityProcessed = uint256.NewInt(0)
	quantityLeft = quantity
	for quantityLeft.Cmp(zero) > 0 && sideToProcess.Len() > 0 {
		levelDone, levelPartial, levelPartialQty, levelLeft, qErr := ob.processQueue(bestQueue(), quantityLeft)
		if qErr != nil {
			return nil, nil, uint256.NewInt(0), uint256.NewInt(0), qErr
		}
		done = append(done, levelDone...)
		// A partial fill always leaves nothing to trade, so only the final
		// iteration can produce a non-nil partial.
		partial = levelPartial
		partialQuantityProcessed = levelPartialQty
		quantityLeft = levelLeft
	}

	return done, partial, partialQuantityProcessed, quantityLeft, nil
}

// ProcessLimitOrder matches a new limit order against the opposite side of the
// book and places any unmatched remainder on the book.
//
// Parameters:
//
//	side     - trade direction (Buy or Sell)
//	orderID  - order ID, unique within the book
//	creator  - address of the order's creator
//	quantity - amount to buy or sell; must be > 0
//	price    - limit price; a Buy matches asks priced <= price, a Sell matches
//	           bids priced >= price; must be > 0
//	nonce    - passed through unchanged to NewOrder
//	* create uint256 values with uint256.NewInt(); see https://github.com/holiman/uint256
//
// Returns:
//
//	done                     - counter orders this order filled completely; if this
//	                           order is itself fully filled, a new Order with this
//	                           orderID and the full quantity, priced at the
//	                           volume-weighted average fill price, is appended last
//	partial                  - the counter order left partially filled, or, when this
//	                           order was partly matched and its remainder rests on the
//	                           book, that resting remainder; otherwise nil
//	makerOrder               - the order placed on the book when a remainder rests,
//	                           otherwise nil
//	partialQuantityProcessed - quantity processed for partial (zero when partial is nil)
//	quantityLeft             - unmatched quantity that was placed on the book
//	                           (zero when the order is fully filled)
//	err                      - ErrOrderExists, ErrInvalidQuantity, ErrInvalidPrice,
//	                           ErrOverflow, or an error from the order side/queue.
//	                           Counter orders matched before a failure stay removed.
func (ob *OrderBook) ProcessLimitOrder(side Side, orderID string, creator common.Address, quantity, price *uint256.Int, nonce uint64) (done []*Order, partial *Order, makerOrder *Order, partialQuantityProcessed, quantityLeft *uint256.Int, err error) {
	if _, ok := ob.orders[orderID]; ok {
		return nil, nil, nil, uint256.NewInt(0), uint256.NewInt(0), ErrOrderExists
	}

	zero := uint256.NewInt(0)
	if quantity.Cmp(zero) <= 0 {
		return nil, nil, nil, uint256.NewInt(0), uint256.NewInt(0), ErrInvalidQuantity
	}
	if price.Cmp(zero) <= 0 {
		return nil, nil, nil, uint256.NewInt(0), uint256.NewInt(0), ErrInvalidPrice
	}

	var (
		sideToProcess *OrderSide
		sideToAdd     *OrderSide
		crosses       func(*uint256.Int) bool // reports whether a counter level at this price is matchable
		bestQueue     func() *OrderQueue
	)
	if side == Buy {
		sideToAdd, sideToProcess = ob.bids, ob.asks
		crosses = func(p *uint256.Int) bool { return price.Cmp(p) >= 0 }
		bestQueue = ob.asks.MinPriceQueue
	} else {
		sideToAdd, sideToProcess = ob.asks, ob.bids
		crosses = func(p *uint256.Int) bool { return price.Cmp(p) <= 0 }
		bestQueue = ob.bids.MaxPriceQueue
	}

	// Match against the best counter levels while they cross the limit price.
	partialQuantityProcessed = uint256.NewInt(0)
	quantityToTrade := quantity
	for quantityToTrade.Cmp(zero) > 0 && sideToProcess.Len() > 0 {
		best := bestQueue()
		if !crosses(best.Price()) {
			break
		}
		levelDone, levelPartial, levelPartialQty, levelLeft, qErr := ob.processQueue(best, quantityToTrade)
		if qErr != nil {
			return nil, nil, nil, uint256.NewInt(0), uint256.NewInt(0), qErr
		}
		done = append(done, levelDone...)
		partial = levelPartial
		partialQuantityProcessed = levelPartialQty
		quantityToTrade = levelLeft
	}
	// Report the real remainder (previously the named result was shadowed in
	// the loop and always returned the original quantity).
	quantityLeft = quantityToTrade

	if quantityToTrade.Cmp(zero) > 0 {
		// Rest the unmatched remainder on this order's own side of the book.
		o := NewOrder(orderID, creator, side, quantityToTrade, price, nonce)
		if len(done) > 0 {
			matched, overflow := new(uint256.Int).SubOverflow(quantity, quantityToTrade)
			if overflow {
				return nil, nil, nil, uint256.NewInt(0), uint256.NewInt(0), ErrOverflow
			}
			partial = o
			partialQuantityProcessed = matched
		}
		el, appendErr := sideToAdd.Append(o)
		if appendErr != nil {
			return nil, nil, nil, uint256.NewInt(0), uint256.NewInt(0), appendErr
		}
		// Index the order only once it is actually in the book, so a failed
		// Append cannot leave a stale entry behind.
		ob.orders[orderID] = el
		return done, partial, o, partialQuantityProcessed, quantityLeft, nil
	}

	// Fully filled: report this order at its volume-weighted average fill price.
	avgPrice, avgErr := limitOrderFillPrice(done, partial, partialQuantityProcessed)
	if avgErr != nil {
		return nil, nil, nil, uint256.NewInt(0), uint256.NewInt(0), avgErr
	}
	done = append(done, NewOrder(orderID, creator, side, quantity, avgPrice, nonce))
	return done, partial, nil, partialQuantityProcessed, quantityLeft, nil
}

// limitOrderFillPrice returns the volume-weighted average price of a fully
// matched limit order: every counter order in done plus, when partialQty > 0,
// partialQty units taken from partial. It returns ErrOverflow if any sum or
// product overflows 256 bits.
//
// NOTE(review): done orders are weighted by order.Quantity; if Quantity is the
// order's original size rather than the amount matched in this call, makers
// that were already part-filled are over-weighted — confirm against Order.
func limitOrderFillPrice(done []*Order, partial *Order, partialQty *uint256.Int) (*uint256.Int, error) {
	totalQuantity := uint256.NewInt(0)
	totalValue := uint256.NewInt(0)

	accumulate := func(p, q *uint256.Int) error {
		qty, overflow := new(uint256.Int).AddOverflow(totalQuantity, q)
		if overflow {
			return ErrOverflow
		}
		notional, overflow := new(uint256.Int).MulOverflow(p, q)
		if overflow {
			return ErrOverflow
		}
		value, overflow := new(uint256.Int).AddOverflow(totalValue, notional)
		if overflow {
			return ErrOverflow
		}
		totalQuantity, totalValue = qty, value
		return nil
	}

	for _, o := range done {
		if err := accumulate(o.Price, o.Quantity); err != nil {
			return nil, err
		}
	}
	// partial is only dereferenced when a partial quantity was actually taken.
	if partialQty.Cmp(uint256.NewInt(0)) > 0 {
		if err := accumulate(partial.Price, partialQty); err != nil {
			return nil, err
		}
	}

	// uint256 Div yields 0 for a zero divisor instead of panicking.
	return new(uint256.Int).Div(totalValue, totalQuantity), nil
}

// processQueue matches up to quantityToTrade against the orders of a single
// price level, consuming them from the head of orderQueue.
//
// Returns:
//
//	done                     - head orders that were consumed entirely; each one is
//	                           removed from the book via CancelOrder
//	partial                  - the counter order that was only partly consumed (after
//	                           Fill and OrderQueue.Update), or nil
//	partialQuantityProcessed - quantity taken from partial, i.e. how much of the last
//	                           order was filled (zero when partial is nil)
//	quantityLeft             - quantity still unmatched after this level
//	err                      - ErrOverflow or an error from CancelOrder
func (ob *OrderBook) processQueue(orderQueue *OrderQueue, quantityToTrade *uint256.Int) (done []*Order, partial *Order, partialQuantityProcessed, quantityLeft *uint256.Int, err error) {
	partialQuantityProcessed = uint256.NewInt(0)
	quantityLeft = quantityToTrade

	for orderQueue.Len() > 0 && quantityLeft.Cmp(uint256.NewInt(0)) > 0 {
		headOrderEl := orderQueue.Head()
		headOrder := headOrderEl.Value.(*Order)

		if quantityLeft.Cmp(headOrder.Remaining()) < 0 {
			// The head order is larger than what is left to trade: consume
			// quantityLeft from it and keep it at the head of the queue.
			// NOTE(review): Fill receives Remaining()-quantityLeft here but
			// Remaining() in the full-fill branch below; both can only be right
			// under one reading of Fill's argument (amount filled vs. new
			// remaining size) — confirm against Order.Fill.
			quantity, overflow := new(uint256.Int).SubOverflow(headOrder.Remaining(), quantityLeft)
			if overflow {
				return nil, nil, uint256.NewInt(0), uint256.NewInt(0), ErrOverflow
			}
			headOrder.Fill(quantity)
			partial = headOrder
			partialQuantityProcessed = quantityLeft
			orderQueue.Update(headOrderEl, partial)
			quantityLeft = uint256.NewInt(0)
		} else {
			// The head order is consumed entirely: reduce the amount left to
			// trade by its remaining size and remove it from the book.
			overflow := false
			quantityLeft, overflow = new(uint256.Int).SubOverflow(quantityLeft, headOrder.Remaining())
			if overflow {
				return nil, nil, uint256.NewInt(0), uint256.NewInt(0), ErrOverflow
			}
			// NOTE(review): CancelOrder returns (nil, nil) for an ID missing from
			// ob.orders, in which case order.Fill below would be called on a nil
			// *Order — presumably the index is always in sync; verify.
			order, err := ob.CancelOrder(headOrder.Id)
			if err != nil {
				return nil, nil, uint256.NewInt(0), uint256.NewInt(0), err
			}
			order.Fill(headOrder.Remaining())
			done = append(done, order)
		}
	}

	return
}

// Order looks up a resting order by its ID and returns it, or nil when the
// book holds no order with that ID.
func (ob *OrderBook) Order(orderID string) *Order {
	if el, found := ob.orders[orderID]; found {
		return el.Value.(*Order)
	}
	return nil
}

// Depth lists the book's price levels with each level's volume. Both the ask
// and the bid slices are ordered from the highest price to the lowest; an
// empty side yields a nil slice.
func (ob *OrderBook) Depth() (asks, bids []*PriceLevel) {
	collect := func(s *OrderSide) []*PriceLevel {
		var levels []*PriceLevel
		for q := s.MaxPriceQueue(); q != nil; q = s.LessThan(q.Price()) {
			levels = append(levels, &PriceLevel{Price: q.Price(), Quantity: q.Volume()})
		}
		return levels
	}

	asks = collect(ob.asks)
	bids = collect(ob.bids)
	return asks, bids
}

// CancelOrder removes the order with the given ID from the book and returns
// whatever its side's Remove returns. An unknown ID is not an error: the
// result is (nil, nil).
func (ob *OrderBook) CancelOrder(orderID string) (*Order, error) {
	el, exists := ob.orders[orderID]
	if !exists {
		return nil, nil
	}
	delete(ob.orders, orderID)

	// Buy orders rest on the bid side; everything else rests on the ask side.
	side := ob.asks
	if el.Value.(*Order).Side == Buy {
		side = ob.bids
	}
	return side.Remove(el)
}

// CalculateMarketPrice reports what it would cost to take quantity from the
// opposite side of the book at market. It walks levels from the best price
// (lowest ask for a Buy, highest bid otherwise) without modifying the book.
//
// Returns:
//
//	price - sum, over the levels touched, of levelPrice * quantity taken there
//	quant - total quantity those levels supply (never more than quantity)
//	err   - ErrInsufficientQuantity when the side runs out before quantity is
//	        covered; price and quant then total every level on that side.
//	        ErrOverflow (with nil price and quant) if any step overflows.
func (ob *OrderBook) CalculateMarketPrice(side Side, quantity *uint256.Int) (price *uint256.Int, quant *uint256.Int, err error) {
	price, quant = uint256.NewInt(0), uint256.NewInt(0)

	var (
		level *OrderQueue
		next  func(*uint256.Int) *OrderQueue
	)
	switch side {
	case Buy:
		level, next = ob.asks.MinPriceQueue(), ob.asks.GreaterThan
	default:
		level, next = ob.bids.MaxPriceQueue(), ob.bids.LessThan
	}

	zero := uint256.NewInt(0)
	for quantity.Cmp(zero) > 0 && level != nil {
		levelVolume := level.Volume()
		levelPrice := level.Price()

		// Consume the whole level when it does not exceed what is still
		// needed; otherwise take only the rest and stay on this level.
		consumeLevel := quantity.Cmp(levelVolume) >= 0
		take := quantity
		if consumeLevel {
			take = levelVolume
		}

		var (
			cost     *uint256.Int
			overflow bool
		)
		if cost, overflow = new(uint256.Int).MulOverflow(levelPrice, take); overflow {
			return nil, nil, ErrOverflow
		}
		if price, overflow = new(uint256.Int).AddOverflow(price, cost); overflow {
			return nil, nil, ErrOverflow
		}
		if quant, overflow = new(uint256.Int).AddOverflow(quant, take); overflow {
			return nil, nil, ErrOverflow
		}
		if quantity, overflow = new(uint256.Int).SubOverflow(quantity, take); overflow {
			return nil, nil, ErrOverflow
		}

		if consumeLevel {
			level = next(levelPrice)
		}
	}

	if quantity.Cmp(zero) > 0 {
		err = ErrInsufficientQuantity
	}
	return price, quant, err
}

// String implements fmt.Stringer: the ask side, then a dashed divider line,
// then the bid side.
func (ob *OrderBook) String() string {
	const divider = "\r\n------------------------------------"
	asks := ob.asks.String()
	bids := ob.bids.String()
	return asks + divider + bids
}

// MarshalJSON 实现json.Marshaler接口
func (ob *OrderBook) MarshalJSON() ([]byte, error) {
	return json.Marshal(
		&struct {
			Asks *OrderSide `json:"asks"`
			Bids *OrderSide `json:"bids"`
		}{
			Asks: ob.asks,
			Bids: ob.bids,
		},
	)
}

// UnmarshalJSON 实现json.Unmarshaler接口
func (ob *OrderBook) UnmarshalJSON(data []byte) error {
	obj := struct {
		Asks *OrderSide `json:"asks"`
		Bids *OrderSide `json:"bids"`
	}{}

	if err := json.Unmarshal(data, &obj); err != nil {
		return err
	}

	ob.asks = obj.Asks
	ob.bids = obj.Bids
	ob.orders = map[string]*list.Element{}

	for _, order := range ob.asks.Orders() {
		ob.orders[order.Value.(*Order).Id] = order
	}

	for _, order := range ob.bids.Orders() {
		ob.orders[order.Value.(*Order).Id] = order
	}

	return nil
}
