// +build opencl

package opencl

/*
#cgo CFLAGS: -I.
#cgo LDFLAGS: -L . -loclsp -lOpenCL
#include "oclsp.h"
#include <string.h>
*/
import "C"

import (
	"errors"
	"github.com/CaduceusMetaverseProtocol/MetaCryptor/xecc/engine"
	. "github.com/CaduceusMetaverseProtocol/MetaCryptor/xecc/types"
	_ "github.com/CaduceusMetaverseProtocol/MetaCryptor/xecc/engine/opencl/statik"
	"github.com/rakyll/statik/fs"
	"io/ioutil"
	"sync"
	"time"
	"unsafe"
)

// clfsname is the path of the OpenCL kernel source inside the embedded
// statik filesystem.
const clfsname = "/k.cl"

// Status codes compared against values produced by the C library.
const (
	verifyOk   = 1 // per-signature verify result meaning "valid"
	cl_success = 1 // return value of the batched C calls meaning success
)

// Fixed byte sizes of the buffers exchanged with the C library.
const (
	msgLength    = 32
	rsignLength  = 65
	signLength   = 64
	pubkeyLength = 65
)

// Device selectors passed as the last argument of secp256_ocl_init.
const (
	openclCPU = 1
	openclGPU = 0
)

// Fixed-size byte arrays whose slices are handed to the C library by pointer
// to their first element.
type (
	// OclMsg is one msgLength-byte message.
	OclMsg [msgLength]byte
	// OclRecoverableSignature is one rsignLength-byte recoverable signature.
	OclRecoverableSignature [rsignLength]byte
	// OclSignature is one signLength-byte signature.
	OclSignature [signLength]byte
	// OclPubkey is one pubkeyLength-byte public key.
	OclPubkey [pubkeyLength]byte
)

// ocle is the process-wide engine returned by GetInstance.
var ocle = &OclEngine{}

// Sentinel errors returned by the OpenCL engine; compare with errors.Is.
var (
	// ErrNotAlready is returned when the engine is used before OpenCL
	// initialisation has succeeded.
	ErrNotAlready = errors.New("OpenCL not already")
	// ErrLengthNotMatch is returned when parallel batch slices differ in length.
	ErrLengthNotMatch = errors.New("param length not matched")
	// ErrCallFunctionFailed is returned when a batched C call does not report success.
	ErrCallFunctionFailed = errors.New("call function failed")
)

// OclEngine runs secp256k1 signature verification and public-key recovery
// through the oclsp OpenCL C library, either synchronously (Process,
// ProcessBatch) or via batched background queues (ProcessA). Obtain the
// shared instance with GetInstance.
type OclEngine struct {
	init    sync.Once     // guards the one-time start of OpenCL initialisation
	mux     sync.Mutex    // serialises the batched C calls and guards the task queues
	already bool          // set once secp256_ocl_init has succeeded; read by Ready
	closed  chan struct{} // routine() returns when this channel is closed

	verifytask  *basetask // pending verify tasks queued by ProcessA
	recovertask *basetask // pending pubkey-recovery tasks queued by ProcessA
}

// GetInstance returns the process-wide OpenCL engine.
//
// On the first call it loads the embedded kernel source and then runs
// secp256_ocl_init in a background goroutine. The call itself returns
// immediately, so callers must check Ready before submitting work. If the
// kernel cannot be loaded (or is empty), or secp256_ocl_init returns non-zero,
// the engine never becomes ready.
func GetInstance() *OclEngine {
	ocle.mux.Lock()
	defer ocle.mux.Unlock()
	ocle.init.Do(func() {
		code, err := getCLSourcode()
		// An empty kernel would make &code[0] panic below; treat it as a load failure.
		if err != nil || len(code) == 0 {
			return
		}
		go func() {
			codelen := C.int(len(code))
			cCode := (*C.uchar)(unsafe.Pointer(&code[0]))
			if ret := C.secp256_ocl_init(cCode, codelen, openclGPU); ret != 0 {
				return
			}
			// Build the queues and the stop channel under the engine lock and
			// only then publish `already`. ProcessA checks Ready() and then uses
			// the queues, so they must exist first.
			// NOTE(review): Ready() reads `already` without the lock; an atomic
			// flag would make this ordering airtight.
			ocle.mux.Lock()
			ocle.closed = make(chan struct{})
			ocle.recovertask = newBaseTask()
			ocle.verifytask = newBaseTask()
			ocle.already = true
			ocle.mux.Unlock()
			go ocle.routine()
		}()
	})
	return ocle
}

// Ready reports whether OpenCL initialisation has completed successfully.
// The flag is read without taking ocl.mux.
func (ocl *OclEngine) Ready() bool {
	initialised := ocl.already
	return initialised
}

// Support reports whether the engine can handle xTask. Only secp256k1
// public-key recovery and signature verification are supported.
func (ocl *OclEngine) Support(xTask XTask) bool {
	switch xTask.TaskId() {
	case FeatureSecp256k1RecoveryPubkey, FeatureSecp256k1Verify:
		return true
	default:
		return false
	}
}

// Name returns the identifier of this engine.
func (ocl *OclEngine) Name() string {
	const engineName = "opencl-engine"
	return engineName
}

// doTask executes a single task synchronously and stores the result on it.
//
// For *XTaskSecp256k1RPubkey it sets Pubkey; for *XTaskSecp256k1Verify it
// sets Verify. The task is returned together with any error from the OpenCL
// call; on error the result field holds the failure value (nil / false).
// Any other task type yields (nil, engine.ErrUnsupport).
func (ocl *OclEngine) doTask(task XTask) (XTask, error) {
	switch t := task.(type) {
	case *XTaskSecp256k1RPubkey:
		pub, err := ocl.OclSecp256RecoverPubkeyS(t.Msg, t.Rsig)
		t.Pubkey = pub
		return t, err
	case *XTaskSecp256k1Verify:
		ok, err := ocl.OclSecp256VerifyS(t.Msg, t.Sig, t.Pubkey)
		t.Verify = ok
		return t, err
	default:
		return nil, engine.ErrUnsupport
	}
}

// doBatchTask runs a homogeneous batch of tasks through one batched OpenCL
// call and writes each result back into its task.
//
// The concrete type of tasks[0] selects the operation:
//   - *XTaskSecp256k1Verify: sets Verify on every task.
//   - *XTaskSecp256k1RPubkey: sets Pubkey (a fresh pubkeyLength-byte slice).
//
// An empty batch is a no-op. An unsupported type, or a batch that mixes
// task types, yields engine.ErrUnsupport. If the OpenCL call fails its error
// is returned and no task is modified.
func (ocl *OclEngine) doBatchTask(tasks []XTask) error {
	num := len(tasks)
	if num == 0 {
		return nil
	}

	switch tasks[0].(type) {
	case *XTaskSecp256k1Verify:
		batch := make([]*XTaskSecp256k1Verify, num)
		for i, task := range tasks {
			vt, ok := task.(*XTaskSecp256k1Verify)
			if !ok {
				return engine.ErrUnsupport
			}
			batch[i] = vt
		}

		var (
			batchMsg = make([]OclMsg, num)
			batchSig = make([]OclSignature, num)
			batchPub = make([]OclPubkey, num)
			batchRet = make([]C.int, num)
		)
		for i, vt := range batch {
			// copy zero-pads short inputs and truncates long ones to the fixed C sizes.
			copy(batchMsg[i][:], vt.Msg)
			copy(batchSig[i][:], vt.Sig)
			copy(batchPub[i][:], vt.Pubkey)
		}

		if err := ocl.batchOclSecp256Verify(batchMsg, batchSig, batchPub, batchRet); err != nil {
			return err
		}
		for i, vt := range batch {
			vt.Verify = batchRet[i] == verifyOk
		}
		return nil

	case *XTaskSecp256k1RPubkey:
		batch := make([]*XTaskSecp256k1RPubkey, num)
		for i, task := range tasks {
			rt, ok := task.(*XTaskSecp256k1RPubkey)
			if !ok {
				return engine.ErrUnsupport
			}
			batch[i] = rt
		}

		var (
			batchMsg  = make([]OclMsg, num)
			batchRsig = make([]OclRecoverableSignature, num)
			batchPub  = make([]OclPubkey, num)
		)
		for i, rt := range batch {
			copy(batchMsg[i][:], rt.Msg)
			copy(batchRsig[i][:], rt.Rsig)
		}

		if err := ocl.batchOclSecp256RecoverPubkey(batchMsg, batchRsig, batchPub); err != nil {
			return err
		}
		for i, rt := range batch {
			rt.Pubkey = make([]byte, pubkeyLength)
			copy(rt.Pubkey, batchPub[i][:])
		}
		return nil

	default:
		return engine.ErrUnsupport
	}
}

// Process runs a single task synchronously. If the engine is not ready, the
// task is returned unchanged with ErrNotAlready.
func (ocl *OclEngine) Process(task XTask) (XTask, error) {
	if ocl.Ready() {
		return ocl.doTask(task)
	}
	return task, ErrNotAlready
}

// ProcessA queues twp for asynchronous, batched processing by routine().
// The task's Report method is called once its result has been stored.
// Returns ErrNotAlready before initialisation and engine.ErrUnsupport for
// task types other than pubkey recovery and verification.
func (ocl *OclEngine) ProcessA(twp *TaskWithReport) error {
	if !ocl.Ready() {
		return ErrNotAlready
	}
	ocl.mux.Lock()
	defer ocl.mux.Unlock()

	var queue *basetask
	switch twp.XTask.(type) {
	case *XTaskSecp256k1RPubkey:
		queue = ocl.recovertask
	case *XTaskSecp256k1Verify:
		queue = ocl.verifytask
	default:
		return engine.ErrUnsupport
	}
	queue.add(twp)
	return nil
}

// ProcessBatch runs a batch of same-typed tasks synchronously and returns the
// same slice with results filled in. If the engine is not ready the tasks are
// returned untouched with ErrNotAlready.
func (ocl *OclEngine) ProcessBatch(tasks []XTask) ([]XTask, error) {
	if ocl.Ready() {
		return tasks, ocl.doBatchTask(tasks)
	}
	return tasks, ErrNotAlready
}

// getCLSourcode reads the OpenCL kernel source (clfsname) from the statik
// filesystem embedded in the binary.
func getCLSourcode() ([]byte, error) {
	embedded, err := fs.New()
	if err != nil {
		return nil, err
	}

	kernel, err := embedded.Open(clfsname)
	if err != nil {
		return nil, err
	}
	defer kernel.Close()

	return ioutil.ReadAll(kernel)
}

// cArrayToGoArray copies size bytes starting at ca into goArray[:size].
//
// ca must point to at least size readable bytes. goArray must have length at
// least size, otherwise it panics. A non-positive size copies nothing and
// does not touch ca.
func cArrayToGoArray(ca unsafe.Pointer, goArray []byte, size int) {
	if size <= 0 {
		return
	}
	// View the C memory as a Go slice without keeping the address in a
	// uintptr variable, which the unsafe.Pointer rules do not allow.
	src := (*[1 << 30]byte)(ca)[:size:size]
	copy(goArray[:size], src)
}

// batchOclSecp256RecoverPubkey recovers one public key per (msg[i], sig[i])
// pair with a single call to secp256k1_ecdsa_recover_ocl, writing the result
// into recpub[i].
//
// All three slices must have the same length, otherwise ErrLengthNotMatch is
// returned. An empty batch is a no-op. ErrCallFunctionFailed is returned when
// the C call does not return cl_success. Calls are serialised on ocl.mux.
func (ocl *OclEngine) batchOclSecp256RecoverPubkey(msg []OclMsg, sig []OclRecoverableSignature, recpub []OclPubkey) error {
	msgcount := len(msg)
	if msgcount != len(sig) || msgcount != len(recpub) {
		return ErrLengthNotMatch
	}
	// Nothing to do, and &msg[0] below would panic on an empty slice.
	if msgcount == 0 {
		return nil
	}

	ocl.mux.Lock()
	defer ocl.mux.Unlock()

	// The Go arrays are handed to C in place.
	// NOTE(review): assumes the ocl_* C types match the Go array sizes — confirm against oclsp.h.
	var (
		cmsg = (*C.ocl_msg)(unsafe.Pointer(&msg[0]))
		csig = (*C.ocl_recoverable_signature)(unsafe.Pointer(&sig[0]))
		cpub = (*C.ocl_pubkey)(unsafe.Pointer(&recpub[0]))
	)

	if cret := C.secp256k1_ecdsa_recover_ocl(C.int(msgcount), csig, cmsg, cpub); cret != cl_success {
		return ErrCallFunctionFailed
	}
	return nil
}

// batchOclSecp256Verify verifies sig[i] over msg[i] against pubkey[i] for
// every i with a single call to secp256k1_ecdsa_verify_ocl. The per-item
// result is written by C into verifyall[i]; verifyOk means valid.
//
// msg, sig and pubkey must have the same length, and verifyall must be at
// least that long, otherwise ErrLengthNotMatch is returned. An empty batch
// is a no-op. ErrCallFunctionFailed is returned when the C call does not
// return cl_success. Calls are serialised on ocl.mux.
func (ocl *OclEngine) batchOclSecp256Verify(msg []OclMsg, sig []OclSignature, pubkey []OclPubkey, verifyall []C.int) error {
	msgcount := len(msg)
	// verifyall is written by C; a shorter slice would be overrun.
	if msgcount != len(sig) || msgcount != len(pubkey) || len(verifyall) < msgcount {
		return ErrLengthNotMatch
	}
	// Nothing to do, and &msg[0] below would panic on an empty slice.
	if msgcount == 0 {
		return nil
	}

	ocl.mux.Lock()
	defer ocl.mux.Unlock()

	var (
		cmsg    = (*C.ocl_msg)(unsafe.Pointer(&msg[0]))
		csig    = (*C.ocl_signature)(unsafe.Pointer(&sig[0]))
		cpub    = (*C.ocl_pubkey)(unsafe.Pointer(&pubkey[0]))
		cverify = (*C.int)(unsafe.Pointer(&verifyall[0]))
	)
	if cret := C.secp256k1_ecdsa_verify_ocl(C.int(msgcount), csig, cmsg, cpub, cverify); cret != cl_success {
		return ErrCallFunctionFailed
	}
	return nil
}

// OclSecp256RecoverPubkeyS recovers the public key for a single (msg, sig)
// pair by running a batch of one. msg and sig are copied into fixed-size
// buffers, so shorter inputs are zero-padded and longer ones truncated.
// Returns ErrNotAlready before initialisation, or the batch call's error.
func (ocl *OclEngine) OclSecp256RecoverPubkeyS(msg []byte, sig []byte) ([]byte, error) {
	if !ocl.already {
		return nil, ErrNotAlready
	}

	msgs := []OclMsg{{}}
	rsigs := []OclRecoverableSignature{{}}
	pubs := []OclPubkey{{}}
	copy(msgs[0][:], msg)
	copy(rsigs[0][:], sig)

	if err := ocl.batchOclSecp256RecoverPubkey(msgs, rsigs, pubs); err != nil {
		return nil, err
	}
	return pubs[0][:], nil
}

// OclSecp256VerifyS verifies a single signature by running a batch of one.
// Inputs are copied into fixed-size buffers, so shorter inputs are
// zero-padded and longer ones truncated. The result is true only when the C
// library reports verifyOk. Returns ErrNotAlready before initialisation, or
// the batch call's error.
func (ocl *OclEngine) OclSecp256VerifyS(msg []byte, sig []byte, pubkey []byte) (bool, error) {
	if !ocl.already {
		return false, ErrNotAlready
	}

	msgs := []OclMsg{{}}
	sigs := []OclSignature{{}}
	pubs := []OclPubkey{{}}
	results := []C.int{0}
	copy(msgs[0][:], msg)
	copy(sigs[0][:], sig)
	copy(pubs[0][:], pubkey)

	if err := ocl.batchOclSecp256Verify(msgs, sigs, pubs, results); err != nil {
		return false, err
	}
	return results[0] == verifyOk, nil
}

// routine is the background batching loop started by GetInstance after
// OpenCL initialisation succeeds.
//
// Every 100ms it takes whichever of the verify / recover queues report
// ready(), runs each through one batched OpenCL call, stores the results on
// the tasks and calls Report() on each. It returns when ocl.closed is closed.
func (ocl *OclEngine) routine() {
	// Created once and stopped on return; allocating a ticker per iteration
	// would leak one live ticker per loop.
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if t := ocl.takeReadyQueue(&ocl.verifytask); t != nil {
				ocl.flushVerifyQueue(t)
			}
			if t := ocl.takeReadyQueue(&ocl.recovertask); t != nil {
				ocl.flushRecoverQueue(t)
			}
		case <-ocl.closed:
			return
		}
	}
}

// takeReadyQueue replaces *queue with a fresh basetask and returns the old one
// if it reports ready(); otherwise it returns nil. The check and the swap both
// happen under ocl.mux, so they cannot interleave with ProcessA's add.
func (ocl *OclEngine) takeReadyQueue(queue **basetask) *basetask {
	ocl.mux.Lock()
	defer ocl.mux.Unlock()
	if !(*queue).ready() {
		return nil
	}
	taken := *queue
	*queue = newBaseTask()
	return taken
}

// flushVerifyQueue verifies every task in t with one batched OpenCL call,
// sets each task's Verify and reports it. If the OpenCL call fails, every
// task is reported with Verify == false.
func (ocl *OclEngine) flushVerifyQueue(t *basetask) {
	num := len(t.tasks)
	if num == 0 {
		return
	}

	batch := make([]*XTaskSecp256k1Verify, num)
	var (
		batchMsg = make([]OclMsg, num)
		batchSig = make([]OclSignature, num)
		batchPub = make([]OclPubkey, num)
		batchRet = make([]C.int, num)
	)
	for i := 0; i < num; i++ {
		// ProcessA only puts *XTaskSecp256k1Verify on this queue.
		vt := t.tasks[i].XTask.(*XTaskSecp256k1Verify)
		batch[i] = vt
		copy(batchMsg[i][:], vt.Msg)
		copy(batchSig[i][:], vt.Sig)
		copy(batchPub[i][:], vt.Pubkey)
	}

	err := ocl.batchOclSecp256Verify(batchMsg, batchSig, batchPub, batchRet)
	for i := 0; i < num; i++ {
		// batchRet is meaningless when the call failed.
		batch[i].Verify = err == nil && batchRet[i] == verifyOk
		t.tasks[i].Report()
	}
}

// flushRecoverQueue recovers the public key for every task in t with one
// batched OpenCL call, sets each task's Pubkey and reports it. If the OpenCL
// call fails, every task is reported with a nil Pubkey.
func (ocl *OclEngine) flushRecoverQueue(t *basetask) {
	num := len(t.tasks)
	if num == 0 {
		return
	}

	batch := make([]*XTaskSecp256k1RPubkey, num)
	var (
		batchMsg  = make([]OclMsg, num)
		batchRsig = make([]OclRecoverableSignature, num)
		batchPub  = make([]OclPubkey, num)
	)
	for i := 0; i < num; i++ {
		// ProcessA only puts *XTaskSecp256k1RPubkey on this queue.
		rt := t.tasks[i].XTask.(*XTaskSecp256k1RPubkey)
		batch[i] = rt
		copy(batchMsg[i][:], rt.Msg)
		copy(batchRsig[i][:], rt.Rsig)
	}

	err := ocl.batchOclSecp256RecoverPubkey(batchMsg, batchRsig, batchPub)
	for i := 0; i < num; i++ {
		// On failure batchPub holds no recovered key; leave Pubkey nil rather
		// than reporting an all-zero key.
		var pub []byte
		if err == nil {
			pub = make([]byte, pubkeyLength)
			copy(pub, batchPub[i][:])
		}
		batch[i].Pubkey = pub
		t.tasks[i].Report()
	}
}
