文章目录
- 一、项目背景
- 二、环境及其依赖
- 三、具体功能
- 1.Python生成密钥对
- 2.java生成密钥对
- 3.Python加签验签
- 4.java加签验签
- 四、遇到的问题
- 五、解决方案
一、项目背景
Python对接Java接口互相SM2加签验签
二、环境及其依赖
- python环境
pip3 install gmssl
- java环境
<dependency>
<groupId>org.bouncycastle</groupId>
<artifactId>bcprov-jdk15to18</artifactId>
<version>1.66</version>
</dependency>
<dependency>
<groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId>
<version>5.8.18</version>
</dependency>
三、具体功能
1.Python生成密钥对
from gmssl import sm2 as SM2
from gmssl import func as GMFunc
from random import SystemRandom
from base64 import b64encode, b64decode
class CurveFp:
    """Container for elliptic-curve domain parameters.

    Stores the curve coefficients ``A`` and ``B``, the field prime ``P``, the
    base-point order ``N``, the base point ``(Gx, Gy)`` and a display ``name``.
    """

    def __init__(self, A, B, P, N, Gx, Gy, name):
        # Assigned left to right, so attributes are created in the order A..name.
        self.A, self.B, self.P, self.N, self.Gx, self.Gy, self.name = A, B, P, N, Gx, Gy, name
class SM2Key:
    """Pure-Python elliptic-curve arithmetic used to derive SM2 public keys.

    Points are ``(x, y)`` tuples in affine coordinates or ``(X, Y, Z)`` tuples in
    Jacobian coordinates (x = X / Z^2, y = Y / Z^3). Every method is a static
    helper that takes the curve parameters explicitly: ``A`` is the curve
    coefficient a, ``P`` the field prime and ``N`` the order of the base point.
    """
    # SM2 recommended 256-bit prime-field curve parameters.
    sm2p256v1 = CurveFp(
        name="sm2p256v1",
        A=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC,
        B=0x28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93,
        P=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF,
        N=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123,
        Gx=0x32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7,
        Gy=0xBC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0
    )

    @staticmethod
    def multiply(a, n, N, A, P):
        """Return ``n * a`` for the affine point ``a`` as an affine ``(x, y)`` tuple."""
        return SM2Key.fromJacobian(SM2Key.jacobianMultiply(SM2Key.toJacobian(a), n, N, A, P), P)

    @staticmethod
    def add(a, b, A, P):
        """Return ``a + b`` for affine points ``a`` and ``b`` as an affine ``(x, y)`` tuple."""
        return SM2Key.fromJacobian(SM2Key.jacobianAdd(SM2Key.toJacobian(a), SM2Key.toJacobian(b), A, P), P)

    @staticmethod
    def inv(a, n):
        """Return the inverse of ``a`` modulo ``n`` (extended Euclidean algorithm).

        Returns 0 when ``a == 0``. NOTE(review): when ``a`` is a nonzero multiple
        of ``n`` the loop never runs and 1 is returned, which is not an inverse;
        callers only pass values modulo the prime P, so this is presumably unreachable.
        """
        if a == 0:
            return 0
        lm, hm = 1, 0
        low, high = a % n, n
        while low > 1:
            r = high // low
            nm, new = hm - lm * r, high - low * r
            lm, low, hm, high = nm, new, lm, low
        return lm % n

    @staticmethod
    def toJacobian(Xp_Yp):
        """Lift an affine point ``(x, y)`` to Jacobian coordinates with Z = 1."""
        Xp, Yp = Xp_Yp
        return Xp, Yp, 1

    @staticmethod
    def fromJacobian(Xp_Yp_Zp, P):
        """Convert a Jacobian point to affine ``(X / Z^2 mod P, Y / Z^3 mod P)``.

        A point with Z == 0 (infinity) maps to ``(0, 0)`` because ``inv(0, P)`` is 0.
        """
        Xp, Yp, Zp = Xp_Yp_Zp
        z = SM2Key.inv(Zp, P)
        return (Xp * z ** 2) % P, (Yp * z ** 3) % P

    @staticmethod
    def jacobianDouble(Xp_Yp_Zp, A, P):
        """Return ``2 * p`` for a Jacobian point using the general-``a`` doubling formula.

        A point with Y == 0 yields ``(0, 0, 0)``. Note that ``jacobianAdd`` and
        ``jacobianMultiply`` use ``(0, 0, 1)`` for infinity; both map to ``(0, 0)``
        in ``fromJacobian``.
        """
        Xp, Yp, Zp = Xp_Yp_Zp
        if not Yp:
            return 0, 0, 0
        ysq = (Yp ** 2) % P
        S = (4 * Xp * ysq) % P
        # M = 3x^2 + a*Z^4: the tangent slope numerator in Jacobian form.
        M = (3 * Xp ** 2 + A * Zp ** 4) % P
        nx = (M ** 2 - 2 * S) % P
        ny = (M * (S - nx) - 8 * ysq ** 2) % P
        nz = (2 * Yp * Zp) % P
        return nx, ny, nz

    @staticmethod
    def jacobianAdd(Xp_Yp_Zp, Xq_Yq_Zq, A, P):
        """Return ``p + q`` for two Jacobian points.

        A point whose Y is 0 is treated as the identity. When the normalized
        x-coordinates match, the result is ``(0, 0, 1)`` if the y-coordinates
        differ (``p + (-p)``), otherwise the doubling of ``p``.
        """
        Xp, Yp, Zp = Xp_Yp_Zp
        Xq, Yq, Zq = Xq_Yq_Zq
        if not Yp:
            return Xq, Yq, Zq
        if not Yq:
            return Xp, Yp, Zp
        # Bring both points to a common denominator before comparing.
        U1 = (Xp * Zq ** 2) % P
        U2 = (Xq * Zp ** 2) % P
        S1 = (Yp * Zq ** 3) % P
        S2 = (Yq * Zp ** 3) % P
        if U1 == U2:
            if S1 != S2:
                return 0, 0, 1
            return SM2Key.jacobianDouble((Xp, Yp, Zp), A, P)
        H = U2 - U1
        R = S2 - S1
        H2 = (H * H) % P
        H3 = (H * H2) % P
        U1H2 = (U1 * H2) % P
        nx = (R ** 2 - H3 - 2 * U1H2) % P
        ny = (R * (U1H2 - nx) - S1 * H3) % P
        nz = (H * Zp * Zq) % P
        return nx, ny, nz

    @staticmethod
    def jacobianMultiply(Xp_Yp_Zp, n, N, A, P):
        """Return ``n * p`` for a Jacobian point by recursive double-and-add.

        ``n`` is reduced modulo ``N`` when it is negative or >= ``N``. ``n == 0``
        or a point with Y == 0 returns ``(0, 0, 1)``. The recursion depth is about
        the bit length of ``n``.
        """
        Xp, Yp, Zp = Xp_Yp_Zp
        if Yp == 0 or n == 0:
            return (0, 0, 1)
        if n == 1:
            return (Xp, Yp, Zp)
        if n < 0 or n >= N:
            return SM2Key.jacobianMultiply((Xp, Yp, Zp), n % N, N, A, P)
        if (n % 2) == 0:
            # Even n: n*p = 2 * ((n/2) * p).
            return SM2Key.jacobianDouble(SM2Key.jacobianMultiply((Xp, Yp, Zp), n // 2, N, A, P), A, P)
        if (n % 2) == 1:
            # Odd n: n*p = 2 * ((n//2) * p) + p.
            mv = SM2Key.jacobianMultiply((Xp, Yp, Zp), n // 2, N, A, P)
            return SM2Key.jacobianAdd(SM2Key.jacobianDouble(mv, A, P), (Xp, Yp, Zp), A, P)
class PrivateKey:
    """An SM2 private scalar on ``curve``, drawn at random when none is supplied."""

    def __init__(self, curve=SM2Key.sm2p256v1, secret=None):
        self.curve = curve
        # Any falsy secret (None, 0) is replaced by a CSPRNG draw from [1, N-1].
        if not secret:
            secret = SystemRandom().randrange(1, curve.N)
        self.secret = secret

    def PublicKey(self):
        """Derive the matching public point ``secret * G`` as a ``PublicKey``."""
        c = self.curve
        base_point = (c.Gx, c.Gy)
        x, y = SM2Key.multiply(base_point, self.secret, A=c.A, P=c.P, N=c.N)
        return PublicKey(x, y, c)

    def ToString(self):
        """Return the secret as 64 lowercase hex digits, zero-padded on the left."""
        return hex(self.secret)[2:].zfill(64)
class PublicKey:
    """An SM2 public point ``(x, y)`` on ``curve``."""

    def __init__(self, x, y, curve):
        self.x = x
        self.y = y
        self.curve = curve

    def ToString(self, compressed=True):
        """Serialize the point as hex in the standard EC point format.

        :param compressed: if True, return the compressed form: ``02`` (even y)
            or ``03`` (odd y) followed by the 64-digit x-coordinate. If False,
            return the uncompressed form ``04 || x || y`` (130 hex digits).
        :return: lowercase hex string.
        """
        x_hex = hex(self.x)[2:].zfill(64)
        if compressed:
            # The previous code returned '04' + unpadded x, which is neither a valid
            # compressed ('02'/'03' + x) nor uncompressed ('04' + x + y) encoding.
            prefix = '03' if self.y & 1 else '02'
            return prefix + x_hex
        return '04' + x_hex + hex(self.y)[2:].zfill(64)
class SM2Util:
    """Convenience wrapper around gmssl's ``CryptSM2`` with Base64 ciphertexts.

    :param pub_key: public key hex (raw ``x || y``; the demo strips the ``04`` prefix).
    :param pri_key: private key hex.
    """

    def __init__(self, pub_key=None, pri_key=None):
        self.pub_key = pub_key
        self.pri_key = pri_key
        self.sm2 = SM2.CryptSM2(public_key=self.pub_key, private_key=self.pri_key)

    def Encrypt(self, data):
        """Encrypt the UTF-8 bytes of ``data``; return the ciphertext Base64-encoded."""
        info = self.sm2.encrypt(data.encode())
        return b64encode(info).decode()

    def Decrypt(self, data):
        """Decrypt a Base64 ciphertext produced by ``Encrypt``; return UTF-8 text.

        :raises ValueError: if the underlying ``decrypt`` reports failure (returns None).
        """
        plain = self.sm2.decrypt(b64decode(data.encode()))
        if plain is None:
            raise ValueError('SM2 decryption failed')
        return plain.decode()

    def Sign(self, data):
        """Sign ``data`` with SM3withSM2 and return the hex signature.

        ``sign_with_sm3`` hashes ``Z || M`` with SM3 before signing, where Z is
        derived from the user ID and this instance's public key. The previous
        ``sign(data.encode(), k)`` signed the raw, unhashed message bytes, which is
        not SM3withSM2 and cannot be verified by Java/BouncyCastle/hutool.
        """
        import secrets
        # The nonce k must be unpredictable: a guessable or repeated k leaks the
        # private key, so use the OS CSPRNG instead of random.choice.
        random_hex_str = ''.join(
            secrets.choice('0123456789abcdef') for _ in range(self.sm2.para_len))
        return self.sm2.sign_with_sm3(data.encode(), random_hex_str)

    def Verify(self, data, sign):
        """Verify an SM3withSM2 signature produced by ``Sign`` over ``data``."""
        return self.sm2.verify_with_sm3(sign, data.encode())

    @staticmethod
    def GenKeyPair():
        """Return ``(private_hex, uncompressed_public_hex)`` for a fresh random key pair."""
        pri = PrivateKey()
        pub = pri.PublicKey()
        return pri.ToString(), pub.ToString(compressed=False)
def main():
    """Demo: generate a key pair, then sign/verify and encrypt/decrypt random text."""
    import random
    vs = '我是笨蛋'
    # Sample plaintext: 500 characters picked at random from vs.
    data = ''.join(random.choice(vs) for _ in range(500))
    print('原数据:{}'.format(data))
    e = SM2Util.GenKeyPair()
    print('私钥:{} 公钥:{}'.format(e[0], e[1]))
    # CryptSM2 works with the raw x || y public key, so drop the leading '04'.
    sm2 = SM2Util(pri_key=e[0], pub_key=e[1][2:])
    sign = sm2.Sign(data)
    print('签名:{} 验签:{}'.format(sign, sm2.Verify(data, sign)))
    cipher = sm2.Encrypt(data)
    print('加密:{}\n解密:{}'.format(cipher, sm2.Decrypt(cipher)))


if __name__ == '__main__':
    main()
2.java生成密钥对
package SM;
import cn.hutool.core.util.HexUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.crypto.BCUtil;
import cn.hutool.crypto.SecureUtil;
import cn.hutool.crypto.SmUtil;
import cn.hutool.crypto.asymmetric.KeyType;
import cn.hutool.crypto.asymmetric.SM2;
import java.nio.charset.StandardCharsets;
import java.security.KeyPair;
import java.util.Base64;
import org.bouncycastle.crypto.digests.SM3Digest;
import org.bouncycastle.crypto.engines.SM2Engine;
import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey;
public class Keys {

    /**
     * Generates a fresh SM2 key pair with hutool and prints both keys as hex.
     */
    public static void main(String[] args) {
        final String text = "我是笨蛋";
        System.out.println("原文:" + text);
        printKeyPair(newSm2());
    }

    /**
     * Creates a hutool SM2 instance (the no-arg constructor generates a random key pair),
     * configured for the C1C2C3 ciphertext layout and an SM3 digest.
     */
    private static SM2 newSm2() {
        SM2 engine = new SM2();
        engine.setMode(SM2Engine.Mode.C1C2C3);
        engine.setDigest(new SM3Digest());
        return engine;
    }

    /**
     * Prints the private scalar d and the uncompressed public point Q as hex.
     */
    private static void printKeyPair(SM2 engine) {
        String privateKey = HexUtil.encodeHexStr(engine.getD());
        String publicKey = HexUtil.encodeHexStr(engine.getQ(false));
        System.out.println("privateKey:" + privateKey);
        System.out.println("publicKey:" + publicKey);
    }
}
3.Python加签验签
import gmssl.func as gmssl_func
from baseutils.gmssl import sm2, func
private_key = 'b9069975c3276fab170ce5ea643e635b8f52075e69e6162232ae01555ed12a31'
public_key = '04f9ad444af8e31f993d96a644c6759ae8f9e5056068540eaa0e6d5f6b338d8ac4a7aac58170c1f18a7227c0dd72daee8b4e1e10d3db94aab6ab0fc5cac550f048'
data = 'Hello, SM2!'.encode("utf-8")

# Sign with SM3withSM2: sign_with_sm3 hashes Z || M (Z is derived from the default
# user ID and this public key) before signing. The previous version signed a bare
# SM3(M) digest, which is not what Java/hutool's sign/signHex computes, so signatures
# could not be verified across the two languages.
# NOTE(review): hutool emits DER (ASN.1) signatures by default; pass asn1=True to
# CryptSM2 here, or call sm2.usePlainEncoding() in Java, so both sides agree on
# DER vs raw r || s. Confirm against the hutool version in use.
signer = sm2.CryptSM2(private_key=private_key, public_key=public_key)  # signer
signature = signer.sign_with_sm3(data)
verifier = sm2.CryptSM2(private_key=private_key, public_key=public_key)  # verifier
is_valid = verifier.verify_with_sm3(signature, data)
print("签名数据:", data)
print("签名:", signature)
print("签名验证结果:", is_valid)
4.java加签验签
package SM;
import cn.hutool.core.lang.func.Func;
import cn.hutool.core.util.HexUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.crypto.SecureUtil;
import cn.hutool.crypto.SmUtil;
import cn.hutool.crypto.asymmetric.KeyType;
import cn.hutool.crypto.asymmetric.SM2;
import java.nio.charset.StandardCharsets;
import java.security.KeyPair;
import java.util.Base64;
/**
 * SM2 (Chinese national standard) asymmetric signing/verification demo using hutool.
 */
public class Sm2Test9 {
    public static void main(String[] args) {
        String text = "狗";
        System.out.println("原文:" + text);

        // 1) Freshly generated key pair, printed as Base64 of the DER-encoded key bytes.
        KeyPair pair = SecureUtil.generateKeyPair("SM2");
        byte[] privateKey = pair.getPrivate().getEncoded();
        byte[] publicKey = pair.getPublic().getEncoded();
        System.out.println("公钥:\n" + bytesToBase64(publicKey));
        System.out.println("私钥:\n" + bytesToBase64(privateKey));
        signAndVerify(SmUtil.sm2(privateKey, publicKey), text, "签名:", "验签:");

        // 2) Fixed key pair (Base64 DER). Hard-coded key material is for demonstration only.
        String signPrivateKey = "MIGTAgEAMBMGByqGSM49AgEGCCqBHM9VAYItBHkwdwIBAQQgvIp+WfwJgOvvyPvfaxikVpRD5V5s2Z0hPo2a+GpfVzygCgYIKoEcz1UBgi2hRANCAARfv1UZ0Au40+P8bMqxCRaRx8VCc76S+UTTW2AaoO+5H+Z/XV96Dby1WulAGfOVoPdVpqb4rNcjKvrIjAujC+px";
        String signPublicKey = "MFkwEwYHKoZIzj0CAQYIKoEcz1UBgi0DQgAEX79VGdALuNPj/GzKsQkWkcfFQnO+kvlE01tgGqDvuR/mf11feg28tVrpQBnzlaD3Vaam+KzXIyr6yIwLowvqcQ==";
        byte[] privateKey1 = base64ToBytes(signPrivateKey);
        byte[] publicKey1 = base64ToBytes(signPublicKey);
        System.out.println("公钥00:\n" + bytesToHex(publicKey1));
        System.out.println("私钥00:\n" + bytesToHex(privateKey1));
        signAndVerify(SmUtil.sm2(privateKey1, publicKey1), text, "签名111111:", "验签11111:");
    }

    /**
     * Signs the hex encoding of text with hutool's signHex, verifies the result with
     * verifyHex, and prints both, prefixed by the given labels.
     */
    private static void signAndVerify(SM2 sm2, String text, String signLabel, String verifyLabel) {
        String dataHex = HexUtil.encodeHexStr(text);
        String sign = sm2.signHex(dataHex);
        System.out.println(signLabel + sign);
        boolean verify = sm2.verifyHex(dataHex, sign);
        System.out.println(verifyLabel + verify);
    }

    /**
     * Converts a byte array to a lowercase hex string.
     *
     * @param bytes byte array
     * @return hex string, two characters per byte
     */
    public static String bytesToHex(byte[] bytes) {
        StringBuilder sb = new StringBuilder();
        for (byte b : bytes) {
            sb.append(String.format("%02x", b & 0xff));
        }
        return sb.toString();
    }

    /**
     * Encodes a byte array as Base64.
     *
     * @param bytes byte array
     * @return Base64 string
     */
    private static String bytesToBase64(byte[] bytes) {
        byte[] encodedBytes = Base64.getEncoder().encode(bytes);
        return new String(encodedBytes, StandardCharsets.UTF_8);
    }

    /**
     * Decodes a Base64 string to bytes.
     *
     * @param base64Str Base64 string
     * @return decoded bytes
     */
    private static byte[] base64ToBytes(String base64Str) {
        byte[] bytes = base64Str.getBytes(StandardCharsets.UTF_8);
        return Base64.getDecoder().decode(bytes);
    }

    /**
     * Decodes a hex string (upper- or lowercase) to bytes.
     *
     * @param s hex string with an even number of hex digits
     * @return decoded bytes
     * @throws IllegalArgumentException if s is null, has odd length, or contains a non-hex character
     */
    public static byte[] hexStringToByteArray(String s) {
        if (s == null) {
            throw new IllegalArgumentException("hex string must not be null");
        }
        int len = s.length();
        if (len % 2 != 0) {
            throw new IllegalArgumentException("hex string must have an even length: " + len);
        }
        byte[] data = new byte[len / 2];
        for (int i = 0; i < len; i += 2) {
            // Character.digit returns -1 for non-hex characters; previously that was
            // silently folded into a garbage byte.
            int hi = Character.digit(s.charAt(i), 16);
            int lo = Character.digit(s.charAt(i + 1), 16);
            if (hi < 0 || lo < 0) {
                throw new IllegalArgumentException("invalid hex digit near index " + i);
            }
            data[i / 2] = (byte) ((hi << 4) + lo);
        }
        return data;
    }
}
四、遇到的问题
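- python与java各自生成的密钥对加签验签都没有问题,但是互相验签失败。原因是两端的签名算法并不一致:hutool 的 sign/signHex 采用 SM3withSM2,即先用默认用户 ID(1234567812345678)和公钥计算 Z 值,再对 Z||M 做 SM3 摘要后签名,且签名结果默认为 ASN.1 DER 编码;而 gmssl 的 sign 直接对传入的数据签名,并输出 r||s 拼接格式。因此 Python 端应使用 sign_with_sm3/verify_with_sm3,并让两端采用相同的签名编码(Python 端 asn1=True,或 Java 端调用 usePlainEncoding())。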
- python gmssl版本有影响
- Java 端用 hutool 生成密钥对的方式差别很大:hutool 本身提供了几种不同的生成方法,得到的密钥格式各不相同,需要注意(以 30 开头的 ASN.1 DER 格式密钥,目前 Python 端似乎无法直接使用)
五、解决方案
- java兼容python(目前我提供的这个代码是java生成的密钥对python能用,python生成的密钥对java也能用)
- python兼容java(把java类打包成可执行jar文件,在python中执行)
import subprocess
# Raw string: in a normal literal "\S", "\p" and "\J" are invalid escape sequences
# (SyntaxWarning on Python 3.12+). An argument list also keeps a path containing
# spaces from being split apart, which str.split() would do.
java_program = ["java", "-jar", r"D:\SM2Python\python_java\JARS\SM2.jar"]
result = subprocess.run(java_program, stdout=subprocess.PIPE, universal_newlines=True)
print(result.stdout)
- 搭建一个 Java 项目,对外提供加签验签接口,供 Python 通过接口调用
- 附上完整的 SM2、SM3 相关文件(sm2.py、sm3.py、func.py)
#sm2.py
import binascii
from Crypto.Util.asn1 import DerSequence, DerInteger
from . import sm3, func
from binascii import unhexlify
# Prime-field elliptic-curve parameters: the SM2 recommended 256-bit curve,
# each value a hex string (64 hex digits per field element).
default_ecc_table = {
    'n': 'FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123',  # order of G; signatures are computed mod n
    'p': 'FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF',  # field prime
    # base point G as x || y (two adjacent literals are concatenated)
    'g': '32c4ae2c1f1981195f9904466a39c9948fe30bbff2660be1715a4589334c74c7'
         'bc3736a2f4f6779c59bdcee36b692153d0a9877cc62a474002df32e52139f0a0',
    'a': 'FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC',  # curve coefficient a
    'b': '28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93',  # curve coefficient b
}
class CryptSM2(object):
    """SM2 signing, verification, encryption and decryption on a prime-field curve.

    Keys, points and intermediate values are hex strings. Points are serialized as
    ``x || y`` (affine) or ``x || y || z`` (Jacobian), each coordinate zero-padded
    to ``para_len`` hex digits, with no ``04`` point-format prefix.
    """

    def __init__(self, private_key, public_key, ecc_table=default_ecc_table, mode=0, asn1=False):
        """
        :param private_key: private key ``d`` as a hex string (used by sign/decrypt).
        :param public_key: public key as a hex string, either raw ``x || y`` or
            uncompressed with a leading ``04`` byte (the prefix is removed).
        :param ecc_table: curve parameters as hex strings (keys n, p, g, a, b).
        :param mode: ciphertext layout, 0 = C1C2C3, 1 = C1C3C2 (default 0).
        :param asn1: if True, signatures are DER ``SEQUENCE {r, s}`` hex instead of
            the raw ``r || s`` form.
        """
        self.private_key = private_key
        self.para_len = len(ecc_table['n'])
        # Strip the "04" prefix only when the key has exactly the prefixed length.
        # The previous ``lstrip("04")`` removed every leading '0'/'4' character,
        # corrupting keys whose x-coordinate starts with 0 or 4. It also stripped
        # raw 128-digit keys that merely happen to begin with "04".
        if (public_key is not None and len(public_key) == 2 * self.para_len + 2
                and public_key.startswith("04")):
            public_key = public_key[2:]
        self.public_key = public_key
        # (a + 3) mod p, used by the doubling formula in _double_point.
        self.ecc_a3 = (
            int(ecc_table['a'], base=16) + 3) % int(ecc_table['p'], base=16)
        self.ecc_table = ecc_table
        assert mode in (0, 1), 'mode must be one of (0, 1)'
        self.mode = mode
        self.asn1 = asn1

    def _kg(self, k, Point):
        """Return ``k * Point`` as affine ``x || y`` hex (left-to-right double-and-add).

        Only the low ``4 * para_len`` bits of ``k`` are used. For ``k == 0`` this
        returns ``Point`` itself rather than infinity, so callers must reject 0.
        Returns None when the Jacobian result has Z == 0 (point at infinity).
        """
        Point = '%s%s' % (Point, '1')  # append Z = 1: affine -> Jacobian
        mask = 1 << (4 * self.para_len - 1)  # top bit of a para_len-hex-digit scalar
        Temp = Point
        flag = False  # set once the leading 1 bit of k has been consumed
        for _ in range(self.para_len * 4):
            if flag:
                Temp = self._double_point(Temp)
            if (k & mask) != 0:
                if flag:
                    Temp = self._add_point(Temp, Point)
                else:
                    flag = True
                    Temp = Point
            k = k << 1
        return self._convert_jacb_to_nor(Temp)

    def _double_point(self, Point):
        """Double a point in Jacobian ``x||y||z`` hex (or affine ``x||y``, taken as z = 1).

        Returns the Jacobian result as ``x3||y3||z3`` hex, or None if ``Point`` is too short.
        """
        l = len(Point)
        len_2 = 2 * self.para_len
        if l < len_2:
            return None
        p = int(self.ecc_table['p'], base=16)
        x1 = int(Point[0:self.para_len], 16)
        y1 = int(Point[self.para_len:len_2], 16)
        z1 = 1 if l == len_2 else int(Point[len_2:], 16)
        T6 = (z1 * z1) % p
        T2 = (y1 * y1) % p
        T3 = (x1 + T6) % p
        T4 = (x1 - T6) % p
        T1 = (T3 * T4) % p
        T3 = (y1 * z1) % p
        T4 = (T2 * 8) % p
        T5 = (x1 * T4) % p
        T1 = (T1 * 3) % p
        T6 = (T6 * T6) % p
        T6 = (self.ecc_a3 * T6) % p
        T1 = (T1 + T6) % p
        z3 = (T3 + T3) % p
        T3 = (T1 * T1) % p
        T2 = (T2 * T4) % p
        x3 = (T3 - T5) % p
        # Halve T5 mod p: add p first when T5 is odd so the shift is exact.
        if (T5 % 2) == 1:
            T4 = (T5 + ((T5 + p) >> 1) - T3) % p
        else:
            T4 = (T5 + (T5 >> 1) - T3) % p
        T1 = (T1 * T4) % p
        y3 = (T1 - T2) % p
        form = '%%0%dx' % self.para_len
        form = form * 3
        return form % (x3, y3, z3)

    def _add_point(self, P1, P2):
        """Mixed addition: Jacobian ``P1`` plus affine ``P2`` (z of P2 is assumed 1).

        Returns the Jacobian sum as ``x3||y3||z3`` hex, or None if an input is too short.
        """
        len_2 = 2 * self.para_len
        l1 = len(P1)
        l2 = len(P2)
        if (l1 < len_2) or (l2 < len_2):
            return None
        p = int(self.ecc_table['p'], base=16)
        X1 = int(P1[0:self.para_len], 16)
        Y1 = int(P1[self.para_len:len_2], 16)
        Z1 = 1 if l1 == len_2 else int(P1[len_2:], 16)
        x2 = int(P2[0:self.para_len], 16)
        y2 = int(P2[self.para_len:len_2], 16)
        T1 = (Z1 * Z1) % p
        T2 = (y2 * Z1) % p
        T3 = (x2 * T1) % p
        T1 = (T1 * T2) % p
        T2 = (T3 - X1) % p
        T3 = (T3 + X1) % p
        T4 = (T2 * T2) % p
        T1 = (T1 - Y1) % p
        Z3 = (Z1 * T2) % p
        T2 = (T2 * T4) % p
        T3 = (T3 * T4) % p
        T5 = (T1 * T1) % p
        T4 = (X1 * T4) % p
        X3 = (T5 - T3) % p
        T2 = (Y1 * T2) % p
        T3 = (T4 - X3) % p
        T1 = (T1 * T3) % p
        Y3 = (T1 - T2) % p
        form = '%%0%dx' % self.para_len
        form = form * 3
        return form % (X3, Y3, Z3)

    def _convert_jacb_to_nor(self, Point):
        """Convert Jacobian ``x||y||z`` hex to affine ``x||y`` hex; None if z == 0 mod p."""
        len_2 = 2 * self.para_len
        p = int(self.ecc_table['p'], base=16)
        x = int(Point[0:self.para_len], 16)
        y = int(Point[self.para_len:len_2], 16)
        z = int(Point[len_2:], 16)
        z_inv = pow(z, p - 2, p)  # Fermat inverse (p is prime); 0 when z == 0
        z_invSquar = (z_inv * z_inv) % p
        z_invQube = (z_invSquar * z_inv) % p
        x_new = (x * z_invSquar) % p
        y_new = (y * z_invQube) % p
        z_new = (z * z_inv) % p
        if z_new == 1:
            form = '%%0%dx' % self.para_len
            form = form * 2
            return form % (x_new, y_new)
        return None

    def verify(self, Sign, data):
        """Verify signature ``Sign`` over the digest ``data`` with ``self.public_key``.

        :param Sign: ``r || s`` hex, or DER hex when ``asn1`` is set.
        :param data: bytes whose big-endian integer value is the message digest e.
        :return: True if the signature is valid, False otherwise.
        """
        n = int(self.ecc_table['n'], base=16)
        if self.asn1:
            unhex_sign = unhexlify(Sign.encode())
            seq_der = DerSequence()
            origin_sign = seq_der.decode(unhex_sign)
            r = origin_sign[0]
            s = origin_sign[1]
        else:
            r = int(Sign[0:self.para_len], 16)
            s = int(Sign[self.para_len:2*self.para_len], 16)
        # r and s must both lie in [1, n-1]. Without this check, s == 0 would
        # reach _kg(0, G), which returns G instead of the point at infinity.
        if not (1 <= r < n and 1 <= s < n):
            return False
        e = int(data.hex(), 16)
        t = (r + s) % n
        if t == 0:
            return False
        P1 = self._kg(s, self.ecc_table['g'])
        P2 = self._kg(t, self.public_key)
        # _add_point cannot add a point to itself, so equal points are doubled instead.
        if P1 == P2:
            P1 = '%s%s' % (P1, 1)
            P1 = self._double_point(P1)
        else:
            P1 = '%s%s' % (P1, 1)
            P1 = self._add_point(P1, P2)
        P1 = self._convert_jacb_to_nor(P1)
        if P1 is None:
            # s*G + t*PA is the point at infinity: there is no x-coordinate to check.
            return False
        x = int(P1[0:self.para_len], 16)
        return r == ((e + x) % n)

    def sign(self, data, K):
        """Sign the digest ``data`` with ``self.private_key`` using nonce ``K``.

        :param data: bytes whose big-endian integer value is the message digest e.
        :param K: random nonce as a hex string. It must be secret and never reused.
        :return: ``r || s`` hex (DER hex when ``asn1`` is set), or None when the
            nonce is unusable (k = 0 mod n, r = 0, r + k = n, or s = 0); retry with
            a fresh K.
        """
        E = data.hex()
        e = int(E, 16)
        d = int(self.private_key, 16)
        n = int(self.ecc_table['n'], base=16)
        # k*G depends only on k mod n; reducing first also rejects k = 0 mod n,
        # for which _kg would return G (k = 0) or None (k = n).
        k = int(K, 16) % n
        if k == 0:
            return None
        P1 = self._kg(k, self.ecc_table['g'])
        x = int(P1[0:self.para_len], 16)
        R = ((e + x) % n)
        if R == 0 or R + k == n:
            return None
        d_1 = pow(d + 1, n - 2, n)  # (1 + d)^-1 mod n (n is prime)
        S = (d_1 * (k + R) - R) % n
        if S == 0:
            return None
        elif self.asn1:
            return DerSequence([DerInteger(R), DerInteger(S)]).encode().hex()
        else:
            return '%064x%064x' % (R, S)

    def encrypt(self, data):
        """Encrypt ``data`` (bytes) to ``self.public_key``.

        Returns ``C1 || C2 || C3`` (mode 0) or ``C1 || C3 || C2`` (mode 1) as bytes,
        or None if the KDF output is all zeros. C1 is the raw ``x || y`` of k*G with
        no ``04`` prefix, and C3 = SM3(x2 || M || y2).
        NOTE(review): BouncyCastle/hutool ciphertexts normally carry a ``04`` byte
        before C1; add or strip it when exchanging ciphertexts with Java.
        """
        msg = data.hex()
        k = func.random_hex(self.para_len)
        C1 = self._kg(int(k, 16), self.ecc_table['g'])
        xy = self._kg(int(k, 16), self.public_key)
        x2 = xy[0:self.para_len]
        y2 = xy[self.para_len:2*self.para_len]
        ml = len(msg)
        t = sm3.sm3_kdf(xy.encode('utf8'), ml / 2)
        if int(t, 16) == 0:
            return None
        form = '%%0%dx' % ml
        C2 = form % (int(msg, 16) ^ int(t, 16))
        C3 = sm3.sm3_hash([
            i for i in bytes.fromhex('%s%s%s' % (x2, msg, y2))
        ])
        if self.mode:
            return bytes.fromhex('%s%s%s' % (C1, C3, C2))
        return bytes.fromhex('%s%s%s' % (C1, C2, C3))

    def decrypt(self, data):
        """Decrypt ciphertext ``data`` (bytes) with ``self.private_key``.

        :return: the plaintext bytes, or None if the KDF output is all zeros or the
            integrity hash C3 does not match (tampered ciphertext, wrong key, or a
            C1C2C3/C1C3C2 mode mismatch).
        """
        data = data.hex()
        len_2 = 2 * self.para_len
        len_3 = len_2 + 64  # C3 is a 32-byte SM3 digest (64 hex digits)
        C1 = data[0:len_2]
        if self.mode:
            C3 = data[len_2:len_3]
            C2 = data[len_3:]
        else:
            C2 = data[len_2:-64]
            C3 = data[-64:]
        xy = self._kg(int(self.private_key, 16), C1)
        x2 = xy[0:self.para_len]
        y2 = xy[self.para_len:len_2]
        cl = len(C2)
        t = sm3.sm3_kdf(xy.encode('utf8'), cl / 2)
        if int(t, 16) == 0:
            return None
        form = '%%0%dx' % cl
        M = form % (int(C2, 16) ^ int(t, 16))
        u = sm3.sm3_hash([
            i for i in bytes.fromhex('%s%s%s' % (x2, M, y2))
        ])
        # The previous version computed u but never compared it with C3, so
        # corrupted or foreign ciphertexts silently decrypted to garbage.
        if u != C3:
            return None
        return bytes.fromhex(M)

    def _sm3_z(self, data, verify_public=None):
        """Return the SM3withSM2 message digest e = SM3(Z || data) as hex.

        Z = SM3(ENTL || ID || a || b || xG || yG || xA || yA), with ENTL = 0x0080 and
        ID = "1234567812345678". Z includes the signer's public key, so the key used
        here must be the same one the verifier will use, or verification fails.

        :param verify_public: optional public key hex (with or without ``04``) to use
            in Z instead of ``self.public_key``.
        """
        sign_verify_key = self.public_key
        if verify_public is not None:
            sign_verify_key = (verify_public if len(verify_public) <= 128 else verify_public[2:])
        z = '0080' + '31323334353637383132333435363738' + \
            self.ecc_table['a'] + self.ecc_table['b'] + self.ecc_table['g'] + \
            sign_verify_key
        z = binascii.a2b_hex(z)
        Za = sm3.sm3_hash(func.bytes_to_list(z))
        M_ = (Za + data.hex()).encode('utf-8')
        e = sm3.sm3_hash(func.bytes_to_list(binascii.a2b_hex(M_)))
        return e

    def sign_with_sm3(self, data, random_hex_str=None, verify_public=None):
        """SM3withSM2-sign the raw message ``data`` (bytes); returns the signature hex."""
        sign_data = binascii.a2b_hex(self._sm3_z(data, verify_public).encode('utf-8'))
        if random_hex_str is None:
            random_hex_str = func.random_hex(self.para_len)
        sign = self.sign(sign_data, random_hex_str)
        return sign

    def verify_with_sm3(self, sign, data):
        """Verify an SM3withSM2 signature over the raw message ``data`` (bytes)."""
        sign_data = binascii.a2b_hex(self._sm3_z(data).encode('utf-8'))
        return self.verify(sign, sign_data)
#sm3.py
import binascii
from math import ceil
from .func import rotl, bytes_to_list
# SM3 initial chaining value V(0): eight 32-bit words.
IV = [
    0x7380166F, 0x4914B2B9, 0x172442D7, 0xDA8A0600,
    0xA96F30BC, 0x163138AA, 0xE38DEE4D, 0xB0FB0E4E,
]
# SM3 round constants: 0x79CC4519 for rounds 0-15, 0x7A879D8A for rounds 16-63.
T_j = [0x79CC4519] * 16 + [0x7A879D8A] * 48
def sm3_ff_j(x, y, z, j):
    """SM3 boolean function FF_j on 32-bit words.

    XOR of the three words for rounds 0-15, bitwise majority for rounds 16-63.

    :raises ValueError: if ``j`` is outside 0..63 (previously an UnboundLocalError).
    """
    if 0 <= j < 16:
        return x ^ y ^ z
    if 16 <= j < 64:
        return (x & y) | (x & z) | (y & z)
    raise ValueError('round index j must be in 0..63, got %r' % (j,))
def sm3_gg_j(x, y, z, j):
    """SM3 boolean function GG_j on 32-bit words.

    XOR of the three words for rounds 0-15; for rounds 16-63, selects bits of
    ``y`` where ``x`` is 1 and bits of ``z`` where ``x`` is 0.

    :raises ValueError: if ``j`` is outside 0..63 (previously an UnboundLocalError).
    """
    if 0 <= j < 16:
        return x ^ y ^ z
    if 16 <= j < 64:
        # ~x is negative in Python, but masking with non-negative z keeps the result in range.
        return (x & y) | ((~x) & z)
    raise ValueError('round index j must be in 0..63, got %r' % (j,))
def sm3_p_0(x):
    """SM3 permutation P0(X) = X ^ (X <<< 9) ^ (X <<< 17) on a 32-bit word."""
    rot9 = rotl(x, 9)
    rot17 = rotl(x, 17)
    return x ^ rot9 ^ rot17
def sm3_p_1(x):
    """SM3 permutation P1(X) = X ^ (X <<< 15) ^ (X <<< 23) on a 32-bit word."""
    rot15 = rotl(x, 15)
    rot23 = rotl(x, 23)
    return x ^ rot15 ^ rot23
def sm3_cf(v_i, b_i):
    """SM3 compression function CF(V_i, B_i).

    :param v_i: current chaining value, a sequence of eight 32-bit ints (not modified).
    :param b_i: one 64-byte message block as a list of ints in 0..255.
    :return: the next chaining value V_{i+1} as a list of eight 32-bit ints.
    """
    # Message expansion: 16 big-endian words from the block, extended to W_0..W_67.
    w = [int.from_bytes(bytes(b_i[k:k + 4]), 'big') for k in range(0, 64, 4)]
    for j in range(16, 68):
        w.append(sm3_p_1(w[j - 16] ^ w[j - 9] ^ rotl(w[j - 3], 15))
                 ^ rotl(w[j - 13], 7) ^ w[j - 6])
    # W'_j = W_j ^ W_{j+4} for j in 0..63.
    w_1 = [w[j] ^ w[j + 4] for j in range(64)]

    a, b, c, d, e, f, g, h = v_i
    for j in range(64):
        a12 = rotl(a, 12)  # used twice below; a is unchanged in between
        ss_1 = rotl((a12 + e + rotl(T_j[j], j % 32)) & 0xffffffff, 7)
        ss_2 = ss_1 ^ a12
        tt_1 = (sm3_ff_j(a, b, c, j) + d + ss_2 + w_1[j]) & 0xffffffff
        tt_2 = (sm3_gg_j(e, f, g, j) + h + ss_1 + w[j]) & 0xffffffff
        d = c
        c = rotl(b, 9)
        b = a
        a = tt_1
        h = g
        g = rotl(f, 19)
        f = e
        e = sm3_p_0(tt_2)
        a, b, c, d, e, f, g, h = (v & 0xFFFFFFFF for v in (a, b, c, d, e, f, g, h))
    # Feed-forward: V_{i+1} = ABCDEFGH ^ V_i.
    return [new ^ old for new, old in zip((a, b, c, d, e, f, g, h), v_i)]
def sm3_hash(msg):
    """Return the SM3 digest of ``msg`` as a 64-character lowercase hex string.

    :param msg: the message as a list of byte values (ints 0..255). Any iterable
        of such ints, e.g. ``bytes``, is also accepted.

    The caller's list is not modified. The previous version appended the padding
    bytes to ``msg`` in place and computed the length field with float division,
    which loses precision for very long inputs.
    """
    msg = list(msg)  # private copy: padding is appended below
    len1 = len(msg)
    # Padding: 0x80, then zeros until the length is 56 mod 64, then the
    # message length in bits as a 64-bit big-endian integer.
    msg.append(0x80)
    msg.extend([0x00] * ((56 - (len1 + 1) % 64) % 64))
    msg.extend(((len1 * 8) % (1 << 64)).to_bytes(8, 'big'))
    V = IV
    for start in range(0, len(msg), 64):
        V = sm3_cf(V, msg[start:start + 64])
    return ''.join('%08x' % word for word in V)
def sm3_kdf(z, klen):
    """SM3-based key derivation function.

    :param z: shared secret as ASCII hex digits in a bytes object (e.g. ``b'ab12...'``).
    :param klen: output length in bytes (converted with ``int``).
    :return: ``klen`` derived bytes as a lowercase hex string.
    """
    klen = int(klen)
    zin = list(bytes.fromhex(z.decode('utf8')))
    digests = []
    # Counter ct starts at 1 and is appended as a 32-bit big-endian integer.
    for ct in range(1, ceil(klen / 32) + 1):
        digests.append(sm3_hash(zin + list(ct.to_bytes(4, 'big'))))
    return ''.join(digests)[:klen * 2]
#func.py
from random import choice
import secrets


def xor(a, b):
    """Element-wise XOR of two int sequences, truncated to the shorter one."""
    return [x ^ y for x, y in zip(a, b)]


def rotl(x, n):
    """Rotate the 32-bit word ``x`` left by ``n`` bits (0 <= n < 32)."""
    return ((x << n) & 0xffffffff) | ((x >> (32 - n)) & 0xffffffff)


def get_uint32_be(key_data):
    """Read a big-endian 32-bit unsigned int from the first four items of ``key_data``."""
    return (key_data[0] << 24) | (key_data[1] << 16) | (key_data[2] << 8) | (key_data[3])


def put_uint32_be(n):
    """Return the 32-bit value ``n`` as a list of four big-endian byte values."""
    return [((n >> 24) & 0xff), ((n >> 16) & 0xff), ((n >> 8) & 0xff), ((n) & 0xff)]


def padding(data, block=16):
    """PKCS#7-pad the int list ``data`` to a multiple of ``block`` bytes.

    Appends ``pad`` copies of ``pad``, where ``pad = block - len(data) % block``
    (a full block when ``data`` is already aligned). The previous version
    hard-coded 16 here, so any ``block`` other than 16 was padded incorrectly.
    """
    pad = block - len(data) % block
    return data + [pad] * pad


def unpadding(data):
    """Strip PKCS#7 padding: drop as many trailing items as the last item's value."""
    return data[:-data[-1]]


def list_to_bytes(data):
    """Convert an iterable of ints in 0..255 to ``bytes``."""
    return bytes(list(data))


def bytes_to_list(data):
    """Convert ``bytes`` (or any iterable of ints) to a list of ints."""
    return list(data)


def random_hex(x):
    """Return ``x`` random lowercase hex digits drawn from the OS CSPRNG.

    This supplies the SM2 per-operation nonce k. ``random.choice`` (Mersenne
    Twister) is predictable, and a predictable or repeated k reveals the signer's
    private key, so ``secrets`` is used instead.
    """
    return ''.join(secrets.choice('0123456789abcdef') for _ in range(x))