coupons chaincode first push

parent 5e9151f1
Pipeline #49 failed with stages
This diff is collapsed.
package main
import (
"testing"
)
var privateKey = "59726308197758576002974483116926761969892956433287453489015419592282702209339"
func TestCreateCert(t *testing.T) {
}
This diff is collapsed.
package main
import "github.com/hyperledger/fabric/core/chaincode/shim"
import (
pb "github.com/hyperledger/fabric/protos/peer"
"fmt"
"encoding/json"
)
const KEY = "COUPONS-"
type CPSChainCode struct {
}
func (t *CPSChainCode) Init(stub shim.ChaincodeStubInterface) pb.Response {
return shim.Success([]byte("init successful! "))
}
func (t *CPSChainCode) Invoke(stub shim.ChaincodeStubInterface) pb.Response {
functionName, args := stub.GetFunctionAndParameters()
paths := splitPath(functionName)
return Router(paths, args, stub)
}
func Router(paths, args []string, stub shim.ChaincodeStubInterface) pb.Response {
switch paths[0] {
case "certManagement": //证书管理
return certManagement(paths,args, stub)
case "getCoupons": //机构查询发行的券信息
return getCouponsApi(args, stub)
case "getCouponsByAddress": //通过UTXO地址查看资产信息
return getCouponsByAddressApi(args, stub)
case "subsidies": //补贴请求
return subsidiesApi(args, stub)
case "createTx":
return transactionProcess(paths,args, stub)
default:
return shim.Error(fmt.Sprintf("Unsupported function %s", paths[0]))
}
}
//入参:券id、机构证书、机构签名
func getCouponsApi(args []string, stub shim.ChaincodeStubInterface) pb.Response {
// 解析请求数据
trans,_,_,err := messageToTrans(CreateCoupon,args,stub)
if err != nil{
return shim.Error(err.Error())
}
var coupon Coupons
expand, ok := trans.Value.(string) //判断类型
if !ok {
return shim.Error("Error expanding parameter type,must be string")
}
err = json.Unmarshal([]byte(expand), &coupon)
if err != nil {
return shim.Error("Parameter resolution failed" + err.Error())
}
coupByte,err :=getStateByte(KEY+coupon.CoupId,stub)
if err != nil{
return shim.Error(err.Error())
}
return shim.Success(coupByte)
}
// 入参:utxo地址,证书、交易签名
func getCouponsByAddressApi(args []string, stub shim.ChaincodeStubInterface) pb.Response {
// 解析请求数据
trans,_,_,err := messageToTrans(CreateCoupon,args,stub)
if err != nil{
return shim.Error(err.Error())
}
//查询utxo信息
utxoByte,err :=getStateByte(KEY+trans.From,stub)
if err != nil{
return shim.Error(err.Error())
}
return shim.Success(utxoByte)
}
// 入参:商户公钥、机构证书、交易签名
//出参:交易流水
func subsidiesApi(args []string, stub shim.ChaincodeStubInterface) pb.Response {
// 解析请求数据
trans,_, _,err := messageToTrans(CreateCoupon,args,stub)
// 根据商户公钥查询出该商户所有UTXO记录
arg :="{\"selector\":{\"public_key\":\""+trans.ToPub+"\"}}"
transferInfo,err :=getStateByConditions(arg,stub)
if err != nil {
return shim.Error(err.Error())
}
return shim.Success(transferInfo)
}
func main() {
err := shim.Start(&CPSChainCode{})
if err != nil {
fmt.Printf("Error starting EncCC chaincode: %s", err)
}
}
\ No newline at end of file
This diff is collapsed.
package main
import (
"bytes"
"encoding/json"
"fmt"
"github.com/hyperledger/fabric/core/chaincode/shim"
"time"
)
//券结构
type Coupons struct {
CoupId string `json:"coup_id"` //券id
IssueOrgId string `json:"issue_org_id"` //券发行机构ID
CoupType string `json:"coup_type"` //券类型
CoupValid string `json:"coup_valid"` //券有效期
CoupAvailableTime string `json:"coup_available_time"` //券可用时间段
CoupDescription string `json:"coup_description"` //券使用说明
CoupQuantity int `json:"coup_quantity"` //发行数量
CoupAmount float64 `json:"coup_amount"` //面额
CoupDiscount float64 `json:"coup_discount"` //折扣
CoupTotalAmount float64 `json:"coup_total_amount"` //券总金额
FloorAmount float64 `json:"floor_amount"` //券使用门槛,最低消费金额
UseMerchats []string `json:"use_merchats"` //当前券支持的商户Id
UseMerchantsType *MerchantsType `json:"use_merchants_type"`
Extensions interface{} `json:"extensions"` //拓展信息,机构自定义
}
type MerchantsType struct {
UseMerchatsType []int `json:"use_merchats_type"` //当前券支持的商户类型
OrgMerchatsType []string `json:"org_merchats_type"` //支持的机构签发类别,len 为0 表示只支持自己机构签发的
IsAll []bool `json:"is_all"` //当前券支持的商户类型是只支持自己签发的商户还是支持全部机构签发的商户
}
//券可用时间段
type CoupAvailable struct {
AvailableDay string `json:"available_day"`
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
}
//UTXO结构
type Utxo struct {
Address string `json:"address"` //Utxo记录地址
PublicKey string `json:"public_key"` //公钥
UserId string `json:"user_id"` //用户Id,代替公钥
Status string `json:"status"` //utxo状态.0发行 1分发 2领取 3 已使用
CoupType string `json:"coup_type"` //券类型
CoupQuantity int `json:"coup_quantity"` //券总量
CoupAmount float64 `json:"coup_amount"` //面额
CoupDiscount float64 `json:"coup_discount"` //折扣
FloorAmount float64 `json:"floor_amount"` //券使用门槛,最低消费金额
PreAddress []string `json:"pre_address"` //前置utxo地址
CoupId string `json:"coup_id"` //券id
}
//订单结构
type Order struct {
OrderId string `json:"order_id"`
CreateTime time.Time `json:"create_time"`
OrderMoney float32 `json:"order_money"`
}
//用户信息
type UserInfo struct {
Uid string `json:"uid"` //用户id
Name string `json:"name"` //用户名
Number string `json:"number"` //身份证号
Phone string `json:"phone"` //手机号
CreateTime time.Time `json:"create_time"` //创建时间
}
//商户信息
type Merchant struct {
Mid string `json:"mid"` //商户id
Type string `json:"type"` //商户类型
Name string `json:"name"` //商户名称
Address string `json:"address"` //商户地址
Cert string `json:"cert"` //证书
CreateTime time.Time `json:"create_time"` //创建时间
}
//渠道信息
type Ditch struct {
Did string `json:"did"` //渠道id
Name string `json:"name"` //渠道名称
Cert string `json:"cert"` //证书
CreateTime time.Time `json:"create_time"` //创建时间
}
//机构信息
type Orgnazation struct {
Oid string `json:"oid"`
Name string `json:"name"` //渠道名称
Cert string `json:"cert"` //证书
CreateTime time.Time `json:"create_time"` //创建时间
}
// 前端传递的数据
type Message struct {
UserId string `json:"user_id"` // 交易发起人的用户Id
Data string `json:"data"` // 交易参数
Sign string `json:"sign"` // 交易签名信息,if是商户交易,这里是券公钥
}
//交易体
type Trans struct {
Value interface{} `json:"value"` // 拓展字段,1、用发行卡券时代表卡券信息。2、使用卡券时代表订单信息.3、查询卡券信息时代表卡券编号。
From string `json:"fromAddress"` //发起方utxo地址
To string `json:"toAddress"` // 接收方utxo地址
ToPub string `json:"toPublicKey"` //接收者公钥
}
func putStateUtxo(key string, utxo *Utxo, stub shim.ChaincodeStubInterface) error {
valueByte, err := json.Marshal(utxo)
if err != nil {
return fmt.Errorf(" %s json marshal data fail,err: %s", key, err)
}
err = stub.PutState(key, valueByte)
if err != nil {
return fmt.Errorf("putState %s data fail,err: %s", key, err)
}
return nil
}
func putStateStruct(key string, value interface{}, stub shim.ChaincodeStubInterface) error {
valueByte, err := json.Marshal(value)
if err != nil {
return fmt.Errorf(" %s json marshal data fail,err: %s", key, err)
}
err = stub.PutState(key, valueByte)
if err != nil {
return fmt.Errorf("putState %s data fail,err: %s", key, err)
}
return nil
}
func putStateByte(key string, value []byte, stub shim.ChaincodeStubInterface) error {
err := stub.PutState(key, value)
if err != nil {
return fmt.Errorf("putState %s data fail,err: %s", key, err)
}
return nil
}
func putStateCoup(key string, coupons Coupons, stub shim.ChaincodeStubInterface) error {
valueByte, err := json.Marshal(coupons)
if err != nil {
return fmt.Errorf(" %s json marshal data fail,err: %s", key, err)
}
err = stub.PutState(key, valueByte)
if err != nil {
return fmt.Errorf("putState %s data fail,err: %s", key, err)
}
return nil
}
func getStateUtxo(key string, stub shim.ChaincodeStubInterface) (*Utxo, error) {
var utxo *Utxo
utxoByteInfo, err := stub.GetState(key)
if err != nil {
return nil, err
}
if utxoByteInfo == nil {
return nil, fmt.Errorf("The query information is empty")
}
err = json.Unmarshal(utxoByteInfo, &utxo)
if err != nil {
return nil, fmt.Errorf("Byte array serialization failed")
}
return utxo, nil
}
func getStateCoupons(key string, stub shim.ChaincodeStubInterface) (*Coupons, error) {
coupons := &Coupons{}
couponsByteInfo, err := stub.GetState(key)
if err != nil {
return nil, err
}
if couponsByteInfo == nil {
return nil, fmt.Errorf("The query information is empty")
}
err = json.Unmarshal(couponsByteInfo, coupons)
if err != nil {
return nil, fmt.Errorf("Byte array serialization failed")
}
return coupons, nil
}
func getStateByte(key string, stub shim.ChaincodeStubInterface) ([]byte, error) {
byteInfo, err := stub.GetState(key)
if err != nil {
return nil, err
}
if byteInfo == nil {
return nil, fmt.Errorf("The query information is empty")
}
return byteInfo, nil
}
func getStateStruct(key string, value interface{}, stub shim.ChaincodeStubInterface) (err error) {
byteInfo, err := stub.GetState(key)
if err != nil {
return err
}
if byteInfo == nil {
return fmt.Errorf("The query information is empty")
}
return json.Unmarshal(byteInfo, value)
}
func getStateByConditions(key string, stub shim.ChaincodeStubInterface) ([]byte, error) {
resultsIterator, err := stub.GetQueryResult(key)
defer resultsIterator.Close()
if err != nil {
return nil, err
}
var buffer bytes.Buffer
buffer.WriteString("[")
bArrayMemberAlreadyWritten := false
for resultsIterator.HasNext() {
queryResponse, err := resultsIterator.Next()
fmt.Println(queryResponse.String())
if err != nil {
return nil, err
}
if bArrayMemberAlreadyWritten == true {
buffer.WriteString(",")
}
buffer.WriteString("{\"Key\":")
buffer.WriteString("\"")
buffer.WriteString(queryResponse.Key)
buffer.WriteString("\"")
buffer.WriteString(", \"Record\":")
buffer.WriteString(string(queryResponse.Value))
buffer.WriteString("}")
bArrayMemberAlreadyWritten = true
}
buffer.WriteString("]")
return buffer.Bytes(), nil
}
//判断key是否存在
func checkKey(key string, stub shim.ChaincodeStubInterface) (bool, error) {
info, err := stub.GetState(key)
if err != nil {
return false, err
}
if info == nil {
return false, nil
}
return true, nil
}
This diff is collapsed.
package main
import (
"strings"
"fmt"
"encoding/json"
"github.com/hyperledger/fabric/core/chaincode/shim"
)
func splitPath(key string) []string {
key = strings.Trim(key, "/ ")
if key == "" {
return []string{}
}
return strings.Split(key, "/")
}
//解析前端数据 证书验证 和 签名验证
func messageToTrans(operation TxType,args []string,stub shim.ChaincodeStubInterface) (*Trans,*Message, string ,error) {
if len(args) != 3 {
return nil,nil, "",fmt.Errorf("put data operation expected more than 3 parameters! ")
}
message := &Message{
UserId: args[0],
Data: args[1],
Sign: args[2],
}
if message.UserId == "" || message.Sign == "" || message.Data == "" {
return nil,nil,"" ,fmt.Errorf("Parameter exception, cannot be null!")
}
//验证证书和签名
pubKey ,err := CheckCertSignature(message,operation,stub)
if err != nil{
return nil,nil,"",err
}
fmt.Println("1",message.Data)
var trans Trans
err = json.Unmarshal([]byte(message.Data), &trans)
if err != nil {
return nil,nil,"", fmt.Errorf("参数解析失败"+err.Error())
}
return &trans,message,pubKey , nil
}
This diff is collapsed.
/* Created by "go tool cgo" - DO NOT EDIT. */
/* package command-line-arguments */
#line 1 "cgo-builtin-prolog"
#include <stddef.h> /* for ptrdiff_t below */
#ifndef GO_CGO_EXPORT_PROLOGUE_H
#define GO_CGO_EXPORT_PROLOGUE_H
typedef struct { const char *p; ptrdiff_t n; } _GoString_;
#endif
/* Start of preamble from import "C" comments. */
/* End of preamble from import "C" comments. */
/* Start of boilerplate cgo prologue. */
#line 1 "cgo-gcc-export-header-prolog"
#ifndef GO_CGO_PROLOGUE_H
#define GO_CGO_PROLOGUE_H
typedef signed char GoInt8;
typedef unsigned char GoUint8;
typedef short GoInt16;
typedef unsigned short GoUint16;
typedef int GoInt32;
typedef unsigned int GoUint32;
typedef long long GoInt64;
typedef unsigned long long GoUint64;
typedef GoInt64 GoInt;
typedef GoUint64 GoUint;
typedef __SIZE_TYPE__ GoUintptr;
typedef float GoFloat32;
typedef double GoFloat64;
typedef float _Complex GoComplex64;
typedef double _Complex GoComplex128;
/*
static assertion to make sure the file is being used on architecture
at least with matching size of GoInt.
*/
typedef char _check_for_64_bit_pointer_matching_GoInt[sizeof(void*)==64/8 ? 1:-1];
typedef _GoString_ GoString;
typedef void *GoMap;
typedef void *GoChan;
typedef struct { void *t; void *v; } GoInterface;
typedef struct { void *data; GoInt len; GoInt cap; } GoSlice;
#endif
/* End of boilerplate cgo prologue. */
#ifdef __cplusplus
extern "C" {
#endif
extern char* C_Hash256(char* p0);
extern char* C_Hash256Bysha3(char* p0);
extern char* C_Sign(char* p0, char* p1);
extern char* C_GenerateKey();
extern char* C_FromPrv(char* p0);
extern GoUint8 C_VerifySignature(char* p0, char* p1, char* p2);
#ifdef __cplusplus
}
#endif
/*
Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sm2
import (
"encoding/pem"
"errors"
"io/ioutil"
"os"
"runtime"
"sync"
)
// Possible certificate files; stop after finding one.
var certFiles = []string{
"/etc/ssl/certs/ca-certificates.crt", // Debian/Ubuntu/Gentoo etc.
"/etc/pki/tls/certs/ca-bundle.crt", // Fedora/RHEL 6
"/etc/ssl/ca-bundle.pem", // OpenSUSE
"/etc/pki/tls/cacert.pem", // OpenELEC
"/etc/pki/ca-trust/extracted/pem/tls-ca-bundle.pem", // CentOS/RHEL 7
}
// CertPool is a set of certificates.
type CertPool struct {
bySubjectKeyId map[string][]int
byName map[string][]int
certs []*Certificate
}
// NewCertPool returns a new, empty CertPool.
func NewCertPool() *CertPool {
return &CertPool{
bySubjectKeyId: make(map[string][]int),
byName: make(map[string][]int),
}
}
// Possible directories with certificate files; stop after successfully
// reading at least one file from a directory.
var certDirectories = []string{
"/etc/ssl/certs", // SLES10/SLES11, https://golang.org/issue/12139
"/system/etc/security/cacerts", // Android
}
var (
once sync.Once
systemRoots *CertPool
systemRootsErr error
)
func systemRootsPool() *CertPool {
once.Do(initSystemRoots)
return systemRoots
}
func initSystemRoots() {
systemRoots, systemRootsErr = loadSystemRoots()
}
func (c *Certificate) systemVerify(opts *VerifyOptions) (chains [][]*Certificate, err error) {
return nil, nil
}
func loadSystemRoots() (*CertPool, error) {
roots := NewCertPool()
var firstErr error
for _, file := range certFiles {
data, err := ioutil.ReadFile(file)
if err == nil {
roots.AppendCertsFromPEM(data)
return roots, nil
}
if firstErr == nil && !os.IsNotExist(err) {
firstErr = err
}
}
for _, directory := range certDirectories {
fis, err := ioutil.ReadDir(directory)
if err != nil {
if firstErr == nil && !os.IsNotExist(err) {
firstErr = err
}
continue
}
rootsAdded := false
for _, fi := range fis {
data, err := ioutil.ReadFile(directory + "/" + fi.Name())
if err == nil && roots.AppendCertsFromPEM(data) {
rootsAdded = true
}
}
if rootsAdded {
return roots, nil
}
}
return nil, firstErr
}
// SystemCertPool returns a copy of the system cert pool.
//
// Any mutations to the returned pool are not written to disk and do
// not affect any other pool.
func SystemCertPool() (*CertPool, error) {
if runtime.GOOS == "windows" {
// Issue 16736, 18609:
return nil, errors.New("crypto/x509: system root pool is not available on Windows")
}
return loadSystemRoots()
}
// findVerifiedParents attempts to find certificates in s which have signed the
// given certificate. If any candidates were rejected then errCert will be set
// to one of them, arbitrarily, and err will contain the reason that it was
// rejected.
func (s *CertPool) findVerifiedParents(cert *Certificate) (parents []int, errCert *Certificate, err error) {
if s == nil {
return
}
var candidates []int
if len(cert.AuthorityKeyId) > 0 {
candidates = s.bySubjectKeyId[string(cert.AuthorityKeyId)]
}
if len(candidates) == 0 {
candidates = s.byName[string(cert.RawIssuer)]
}
for _, c := range candidates {
if err = cert.CheckSignatureFrom(s.certs[c]); err == nil {
parents = append(parents, c)
} else {
errCert = s.certs[c]
}
}
return
}
func (s *CertPool) contains(cert *Certificate) bool {
if s == nil {
return false
}
candidates := s.byName[string(cert.RawSubject)]
for _, c := range candidates {
if s.certs[c].Equal(cert) {
return true
}
}
return false
}
// AddCert adds a certificate to a pool.
func (s *CertPool) AddCert(cert *Certificate) {
if cert == nil {
panic("adding nil Certificate to CertPool")
}
// Check that the certificate isn't being added twice.
if s.contains(cert) {
return
}
n := len(s.certs)
s.certs = append(s.certs, cert)
if len(cert.SubjectKeyId) > 0 {
keyId := string(cert.SubjectKeyId)
s.bySubjectKeyId[keyId] = append(s.bySubjectKeyId[keyId], n)
}
name := string(cert.RawSubject)
s.byName[name] = append(s.byName[name], n)
}
// AppendCertsFromPEM attempts to parse a series of PEM encoded certificates.
// It appends any certificates found to s and reports whether any certificates
// were successfully parsed.
//
// On many Linux systems, /etc/ssl/cert.pem will contain the system wide set
// of root CAs in a format suitable for this function.
func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) {
for len(pemCerts) > 0 {
var block *pem.Block
block, pemCerts = pem.Decode(pemCerts)
if block == nil {
break
}
if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
continue
}
cert, err := ParseCertificate(block.Bytes)
if err != nil {
continue
}
s.AddCert(cert)
ok = true
}
return
}
// Subjects returns a list of the DER-encoded subjects of
// all of the certificates in the pool.
func (s *CertPool) Subjects() [][]byte {
res := make([][]byte, len(s.certs))
for i, c := range s.certs {
res[i] = c.RawSubject
}
return res
}
This diff is collapsed.
/*
Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sm2
import (
"crypto/rsa"
"encoding/asn1"
"errors"
"math/big"
)
// pkcs1PrivateKey is a structure which mirrors the PKCS#1 ASN.1 for an RSA private key.
type pkcs1PrivateKey struct {
Version int
N *big.Int
E int
D *big.Int
P *big.Int
Q *big.Int
// We ignore these values, if present, because rsa will calculate them.
Dp *big.Int `asn1:"optional"`
Dq *big.Int `asn1:"optional"`
Qinv *big.Int `asn1:"optional"`
AdditionalPrimes []pkcs1AdditionalRSAPrime `asn1:"optional,omitempty"`
}
type pkcs1AdditionalRSAPrime struct {
Prime *big.Int
// We ignore these values because rsa will calculate them.
Exp *big.Int
Coeff *big.Int
}
// ParsePKCS1PrivateKey returns an RSA private key from its ASN.1 PKCS#1 DER encoded form.
func ParsePKCS1PrivateKey(der []byte) (*rsa.PrivateKey, error) {
var priv pkcs1PrivateKey
rest, err := asn1.Unmarshal(der, &priv)
if len(rest) > 0 {
return nil, asn1.SyntaxError{Msg: "trailing data"}
}
if err != nil {
return nil, err
}
if priv.Version > 1 {
return nil, errors.New("x509: unsupported private key version")
}
if priv.N.Sign() <= 0 || priv.D.Sign() <= 0 || priv.P.Sign() <= 0 || priv.Q.Sign() <= 0 {
return nil, errors.New("x509: private key contains zero or negative value")
}
key := new(rsa.PrivateKey)
key.PublicKey = rsa.PublicKey{
E: priv.E,
N: priv.N,
}
key.D = priv.D
key.Primes = make([]*big.Int, 2+len(priv.AdditionalPrimes))
key.Primes[0] = priv.P
key.Primes[1] = priv.Q
for i, a := range priv.AdditionalPrimes {
if a.Prime.Sign() <= 0 {
return nil, errors.New("x509: private key contains zero or negative prime")
}
key.Primes[i+2] = a.Prime
// We ignore the other two values because rsa will calculate
// them as needed.
}
err = key.Validate()
if err != nil {
return nil, err
}
key.Precompute()
return key, nil
}
// MarshalPKCS1PrivateKey converts a private key to ASN.1 DER encoded form.
func MarshalPKCS1PrivateKey(key *rsa.PrivateKey) []byte {
key.Precompute()
version := 0
if len(key.Primes) > 2 {
version = 1
}
priv := pkcs1PrivateKey{
Version: version,
N: key.N,
E: key.PublicKey.E,
D: key.D,
P: key.Primes[0],
Q: key.Primes[1],
Dp: key.Precomputed.Dp,
Dq: key.Precomputed.Dq,
Qinv: key.Precomputed.Qinv,
}
priv.AdditionalPrimes = make([]pkcs1AdditionalRSAPrime, len(key.Precomputed.CRTValues))
for i, values := range key.Precomputed.CRTValues {
priv.AdditionalPrimes[i].Prime = key.Primes[2+i]
priv.AdditionalPrimes[i].Exp = values.Exp
priv.AdditionalPrimes[i].Coeff = values.Coeff
}
b, _ := asn1.Marshal(priv)
return b
}
// rsaPublicKey reflects the ASN.1 structure of a PKCS#1 public key.
type rsaPublicKey struct {
N *big.Int
E int
}
This diff is collapsed.
/*
Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sm2
// reference to ecdsa
import (
"bytes"
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/elliptic"
"crypto/rand"
"crypto/sha512"
"encoding/binary"
"errors"
"io"
"math/big"
"github.com/TMChain/go-TMChain_test/crypto/crypto_interface"
"github.com/chaincodecert/sm_crypto/sm3"
)
const (
aesIV = "IV for <SM2> CTR"
)
type PublicKey struct {
elliptic.Curve
X, Y *big.Int
}
type PrivateKey struct {
PublicKey
D *big.Int
}
// The SM2's private key contains the public key
func (priv *PrivateKey) Public() crypto.PublicKey {
return &priv.PublicKey
}
// sign format = 30 + len(z) + 02 + len(r) + r + 02 + len(s) + s, z being what follows its size, ie 02+len(r)+r+02+len(s)+s
func (priv *PrivateKey) Sign(rand io.Reader, msg []byte, opts crypto.SignerOpts) ([]byte, error) {
r, s, err := Sign(priv, msg)
if err != nil {
return nil, err
}
sig := icrypto.FormatSm2SigTo64(r, s)
return sig, nil
}
func SM2Sign(priv *PrivateKey, hash []byte) ([]byte, error) {
r, s, err := Sign(priv, hash)
if err != nil {
return nil, err
}
sig := icrypto.FormatSm2SigTo64(r, s)
return sig, nil
}
func (priv *PrivateKey) Decrypt(data []byte) ([]byte, error) {
return Decrypt(priv, data)
}
func VerifySign(pub *PublicKey, msg []byte, sign []byte) bool {
return pub.Verify(msg, sign)
}
func (pub *PublicKey) Verify(msg []byte, sign []byte) bool {
//兼容处理ecdsa签名数据多余V的byte
if len(sign) == 65 {
sign = sign[:64]
}
r := new(big.Int).SetBytes(sign[:32])
s := new(big.Int).SetBytes(sign[32:])
return Verify(pub, msg, r, s)
}
func (pub *PublicKey) Encrypt(data []byte) ([]byte, error) {
return Encrypt(pub, data)
}
var one = new(big.Int).SetInt64(1)
func intToBytes(x int) []byte {
var buf = make([]byte, 4)
binary.BigEndian.PutUint32(buf, uint32(x))
return buf
}
func kdf(x, y []byte, length int) ([]byte, bool) {
var c []byte
ct := 1
h := sm3.New()
x = append(x, y...)
for i, j := 0, (length+31)/32; i < j; i++ {
h.Reset()
h.Write(x)
h.Write(intToBytes(ct))
hash := h.Sum(nil)
if i+1 == j && length%32 != 0 {
c = append(c, hash[:length%32]...)
} else {
c = append(c, hash...)
}
ct++
}
for i := 0; i < length; i++ {
if c[i] != 0 {
return c, true
}
}
return c, false
}
func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) {
params := c.Params()
b := make([]byte, params.BitSize/8+8)
_, err = io.ReadFull(rand, b)
if err != nil {
return
}
k = new(big.Int).SetBytes(b)
n := new(big.Int).Sub(params.N, one)
k.Mod(k, n)
k.Add(k, one)
return
}
func GenerateKey() (*PrivateKey, error) {
c := P256Sm2()
k, err := randFieldElement(c, rand.Reader)
if err != nil {
return nil, err
}
priv := new(PrivateKey)
priv.PublicKey.Curve = c
priv.D = k
priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
return priv, nil
}
func GenerateKey2(rand io.Reader) (*PrivateKey, error) {
c := P256Sm2()
k, err := randFieldElement(c, rand)
if err != nil {
return nil, err
}
priv := new(PrivateKey)
priv.PublicKey.Curve = c
priv.D = k
priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
return priv, nil
}
var errZeroParam = errors.New("zero parameter")
func Sign(priv *PrivateKey, hash []byte) (r, s *big.Int, err error) {
entropylen := (priv.Curve.Params().BitSize + 7) / 16
if entropylen > 32 {
entropylen = 32
}
entropy := make([]byte, entropylen)
_, err = io.ReadFull(rand.Reader, entropy)
if err != nil {
return
}
// Initialize an SHA-512 hash context; digest ...
md := sha512.New()
md.Write(priv.D.Bytes()) // the private key,
md.Write(entropy) // the entropy,
md.Write(hash) // and the input hash;
key := md.Sum(nil)[:32] // and compute ChopMD-256(SHA-512),
// which is an indifferentiable MAC.
// Create an AES-CTR instance to use as a CSPRNG.
block, err := aes.NewCipher(key)
if err != nil {
return nil, nil, err
}
// Create a CSPRNG that xors a stream of zeros with
// the output of the AES-CTR instance.
csprng := cipher.StreamReader{
R: zeroReader,
S: cipher.NewCTR(block, []byte(aesIV)),
}
// See [NSA] 3.4.1
c := priv.PublicKey.Curve
N := c.Params().N
if N.Sign() == 0 {
return nil, nil, errZeroParam
}
var k *big.Int
e := new(big.Int).SetBytes(hash)
for { // 调整算法细节以实现SM2
for {
k, err = randFieldElement(c, csprng)
if err != nil {
r = nil
return
}
r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
r.Add(r, e)
r.Mod(r, N)
if r.Sign() != 0 {
break
}
if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 {
break
}
}
rD := new(big.Int).Mul(priv.D, r)
s = new(big.Int).Sub(k, rD)
d1 := new(big.Int).Add(priv.D, one)
d1Inv := new(big.Int).ModInverse(d1, N)
s.Mul(s, d1Inv)
s.Mod(s, N)
if s.Sign() != 0 {
break
}
}
return
}
func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
c := pub.Curve
N := c.Params().N
if r.Sign() <= 0 || s.Sign() <= 0 {
return false
}
if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
return false
}
// 调整算法细节以实现SM2
t := new(big.Int).Add(r, s)
t.Mod(t, N)
if N.Sign() == 0 {
return false
}
var x *big.Int
x1, y1 := c.ScalarBaseMult(s.Bytes())
x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
x, _ = c.Add(x1, y1, x2, y2)
e := new(big.Int).SetBytes(hash)
x.Add(x, e)
x.Mod(x, N)
return x.Cmp(r) == 0
}
// 32byte
var zeroByteSlice = []byte{
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
}
/*
* sm2密文结构如下:
* x
* y
* hash
* CipherText
*/
func Encrypt(pub *PublicKey, data []byte) ([]byte, error) {
lenx1 := 0
leny1 := 0
lenx2 := 0
leny2 := 0
length := len(data)
for {
c := []byte{}
curve := pub.Curve
k, err := randFieldElement(curve, rand.Reader)
if err != nil {
return nil, err
}
x1, y1 := curve.ScalarBaseMult(k.Bytes())
x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
lenx1 = len(x1.Bytes())
leny1 = len(y1.Bytes())
lenx2 = len(x2.Bytes())
leny2 = len(y2.Bytes())
if lenx1 < 32 {
c = append(c, zeroByteSlice[:(32-lenx1)]...)
}
c = append(c, x1.Bytes()...) // x分量
if leny1 < 32 {
c = append(c, zeroByteSlice[:(32-leny1)]...)
}
c = append(c, y1.Bytes()...) // y分量
tm := []byte{}
if lenx2 < 32 {
tm = append(tm, zeroByteSlice[:(32-lenx2)]...)
}
tm = append(tm, x2.Bytes()...)
tm = append(tm, data...)
if leny2 < 32 {
tm = append(tm, zeroByteSlice[:(32-leny2)]...)
}
tm = append(tm, y2.Bytes()...)
h := sm3.Sm3Sum(tm)
c = append(c, h...)
ct, ok := kdf(x2.Bytes(), y2.Bytes(), length) // 密文
if !ok {
continue
}
c = append(c, ct...)
for i := 0; i < length; i++ {
c[96+i] ^= data[i]
}
return c, nil
}
}
func Decrypt(priv *PrivateKey, data []byte) ([]byte, error) {
length := len(data) - 96
curve := priv.Curve
x := new(big.Int).SetBytes(data[:32])
y := new(big.Int).SetBytes(data[32:64])
x2, y2 := curve.ScalarMult(x, y, priv.D.Bytes())
c, ok := kdf(x2.Bytes(), y2.Bytes(), length)
if !ok {
return nil, errors.New("Decrypt: failed to decrypt")
}
for i := 0; i < length; i++ {
c[i] ^= data[i+96]
}
tm := []byte{}
tm = append(tm, x2.Bytes()...)
tm = append(tm, c...)
tm = append(tm, y2.Bytes()...)
h := sm3.Sm3Sum(tm)
if bytes.Compare(h, data[64:96]) != 0 {
return c, errors.New("Decrypt: failed to decrypt")
}
return c, nil
}
func getLastBit(a *big.Int) uint {
return a.Bit(0)
}
// 32byte
func zeroByteSlice2() []byte {
return []byte{
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
}
}
// Compress transform publickey point struct to 33 bytes publickey.
func Compress(a *PublicKey) []byte {
buf := []byte{}
yp := getLastBit(a.Y)
buf = append(buf, a.X.Bytes()...)
if n := len(a.X.Bytes()); n < 32 {
buf = append(zeroByteSlice2()[:(32-n)], buf...)
}
// RFC: GB/T 32918.1-2016 4.2.9
// if yp = 0, buf = 02||x
// if yp = 0, buf = 03||x
if yp == uint(0) {
buf = append([]byte{byte(2)}, buf...)
}
if yp == uint(1) {
buf = append([]byte{byte(3)}, buf...)
}
return buf
}
// Decompress transform 33 bytes publickey to publickey point struct.
func Decompress(a []byte) *PublicKey {
var aa, xx, xx3 sm2P256FieldElement
P256Sm2()
x := new(big.Int).SetBytes(a[1:])
curve := sm2P256
sm2P256FromBig(&xx, x)
sm2P256Square(&xx3, &xx) // x3 = x ^ 2
sm2P256Mul(&xx3, &xx3, &xx) // x3 = x ^ 2 * x
sm2P256Mul(&aa, &curve.a, &xx) // a = a * x
sm2P256Add(&xx3, &xx3, &aa)
sm2P256Add(&xx3, &xx3, &curve.b)
y2 := sm2P256ToBig(&xx3)
y := new(big.Int).ModSqrt(y2, sm2P256.P)
// RFC: GB/T 32918.1-2016 4.2.10
// if a[0] = 02, getLastBit(y) = 0
// if a[0] = 03, getLastBit(y) = 1
// if yp = 0, buf = 03||x
if getLastBit(y) != uint(a[0])-2 {
y.Sub(sm2P256.P, y)
}
return &PublicKey{
Curve: P256Sm2(),
X: x,
Y: y,
}
}
type zr struct {
io.Reader
}
func (z *zr) Read(dst []byte) (n int, err error) {
for i := range dst {
dst[i] = 0
}
return len(dst), nil
}
var zeroReader = &zr{}
This diff is collapsed.
This diff is collapsed.
package sm3
// Sum256 returns the SM3 digest of the data.
func Sum256(data []byte) (digest [32]byte) {
hash := Sm3Sum(data)
copy(digest[:], hash)
return
}
// Sum calculate data into hash
func Sum(hash, data []byte) {
tmp := Sm3Sum(data)
copy(hash, tmp)
}
This diff is collapsed.
/*
Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sm3
import (
"fmt"
"io/ioutil"
"log"
"os"
"testing"
)
func byteToString(b []byte) string {
ret := ""
for i := 0; i < len(b); i++ {
ret += fmt.Sprintf("%02x", b[i])
}
fmt.Println("ret = ", ret)
return ret
}
func TestSm3(t *testing.T) {
msg := []byte("test")
err := ioutil.WriteFile("ifile", msg, os.FileMode(0644)) // 生成测试文件
if err != nil {
log.Fatal(err)
}
msg, err = ioutil.ReadFile("ifile")
if err != nil {
log.Fatal(err)
}
hw := New()
hw.Write(msg)
hash := hw.Sum(nil)
fmt.Println(hash)
fmt.Printf("hash = %d\n", len(hash))
fmt.Printf("%s\n", byteToString(hash))
hash1 := Sm3Sum(msg)
fmt.Println(hash1)
fmt.Printf("%s\n", byteToString(hash1))
}
func BenchmarkSm3(t *testing.B) {
t.ReportAllocs()
msg := []byte("test")
hw := New()
for i := 0; i < t.N; i++ {
hw.Sum(nil)
Sm3Sum(msg)
}
}
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment