密钥协商由密钥系统相同的两方进行,一方为发起方A(initiator),另一方为响应方B(responder)。
A的私钥为dA,公钥为PA(xA, yA),
ZA的计算方法,在前面的章节中已经给出了计算方法gm_sm2_compute_z_digest
,可以去回顾一下。
同样的,B的私钥为dB,公钥为PB(xB, yB)以及ZB。
# 协商过程
密钥协商的过程如下:
记
用户A:
A1:生成一对随机SM2密钥对,私钥为
A2:计算
A3:计算
用户B:
B1:生成一对随机SM2密钥对,私钥为
B2:计算
B3:计算
B4:计算
B5:计算椭圆曲线点
B6:计算
B7:将RA的坐标x1、y1和RB的坐标x2、y2的数据类型转换为比特串,计算
B8:将
用户A:
A4:计算
A5:计算椭圆曲线点
A6:计算
A7:将RA的坐标x1、y1和RB的坐标x2、y2的数据类型转换为比特串,计算
A8:计算
用户B:
B9:计算
# 分阶段实现
接下来分阶段来实现,把协商过程分为三个阶段init、calculate、checkhash。
init: A1-A3,B1-B3,初始化阶段,主要是生成临时密钥对,计算过程数据,然后发送给对方
calculate:B4-B7, A4-A8接收到init过程数据后,计算出协商密钥,哈希(杂凑值)。
checkhash:校验哈希是否一致,一致则协商成功,不一致则协商失败。
# init
初始化阶段,密钥对生成之前章节中已经有相应的实现了,拿过来用就行。
来看一下A2的计算,
这里的w = 127,那么2^w就是:
// 00000000 00000000 00000000 00000000 80000000 00000000 00000000 00000000
static const gm_bn_t GM_BN_2W = {
0x00000000, 0x00000000, 0x00000000, 0x80000000,
0x00000000, 0x00000000, 0x00000000, 0x00000000
};
2
3
4
5
2^w - 1就是:
// 00000000 00000000 00000000 00000000 7FFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF
static const gm_bn_t GM_BN_2W_SUB_ONE = {
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x7FFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000
};
2
3
4
5
有了预计算的值,完成的
// 2^w + ( x & ( 2^w − 1 ) )
static void gm_sm2_exch_reduce(gm_bn_t x) {
int i;
int num = GM_BN_ARR_SIZE / 2;
for(i = 0; i < GM_BN_ARR_SIZE; i++) {
if(i < num) {
x[i] &= GM_BN_2W_SUB_ONE[i];
x[i] += GM_BN_2W[i];
}else {
x[i] = 0;
}
}
}
2
3
4
5
6
7
8
9
10
11
12
13
这里为何能写的这么简单,不考虑加法溢出吗?
大家可以看一下,x先与GM_BN_2W_SUB_ONE相与,那么计算后,结果的最大值就是GM_BN_2W_SUB_ONE,那GM_BN_2W_SUB_ONE与GM_BN_2W相加,是不可能有溢出的。
我们来看一下完整的代码实现:
typedef struct {
gm_bn_t t; // tA or tB
unsigned char xy[64]; // 临时公钥(x1, y1) or (x2, y2)
unsigned char z[32]; // ZA or ZB
unsigned char isInitiator; // 1为发起方,否则为响应方
} gm_sm2_exch_context;
/**
* 密钥协商初始化
* @param ctx 上下文
* @param private_key 用户私钥rA or rB
* @param public_key 用户公司RA or RB
* @param isInitiator 1为发起方,否则为响应方
* @param output 输出 RA or RB
*/
void gm_sm2_exch_init(gm_sm2_exch_context * ctx, gm_bn_t private_key, const gm_point_t * public_key,
unsigned char isInitiator, const unsigned char * id_bytes, unsigned int idLen, unsigned char output[64]) {
gm_bn_t r;
gm_point_t pr;
// 生成临时密钥对
gm_sm2_gen_keypair(r, &pr);
gm_sm2_exch_init_for_test(ctx, private_key, public_key, r, &pr, isInitiator, id_bytes, idLen, output);
}
void gm_sm2_exch_init_for_test(gm_sm2_exch_context * ctx, gm_bn_t private_key, const gm_point_t * public_key,
gm_bn_t tmp_private_key, const gm_point_t * tmp_public_key,
unsigned char isInitiator, const unsigned char * id_bytes, unsigned int idLen, unsigned char output[64]) {
gm_bn_t r, x, y;
gm_point_t pr;
gm_bn_copy(r, tmp_private_key);
gm_point_copy(&pr, tmp_public_key);
gm_point_get_xy(&pr, x, y);
// 2^w + ( x & ( 2^w − 1 ) )
gm_sm2_exch_reduce(x);
// t = (d + x · r) mod n
gm_bn_to_mont(x, x, GM_BN_N);
gm_bn_to_mont(r, r, GM_BN_N);
// x * r
gm_bn_mont_mul(r, r, x, GM_BN_N);
// d + x * r
gm_bn_from_mont(r, r, GM_BN_N);
gm_bn_add(ctx->t, r, private_key, GM_BN_N);
// compute z digest
gm_sm2_compute_z_digest(id_bytes, idLen, public_key, ctx->z);
ctx->isInitiator = isInitiator;
gm_point_to_bytes(&pr, ctx->xy);
// output R
memcpy(output, ctx->xy, 64);
}
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# calculate
计算密钥K,用于校验的Hash值
/**
* 计算密钥K,S1/SB、S2/SA
* @param ctx 上下文
* @param peer_p 对方公钥P
* @param peer_r 对方初始化信息 R,即随机公钥
* @param id_bytes 对方user id
* @param idLen 对方user id 长度
* @param kLen 协商的密钥长度(单位字节)
* @param output 输出密钥 k || S1/SB || S2/SA,长度为kLen + 64
*/
void gm_sm2_exch_calculate(gm_sm2_exch_context * ctx, const unsigned char * peer_p, const unsigned char * peer_r,
const unsigned char * id_bytes, unsigned int idLen, int kLen, unsigned char * output) {
unsigned char buf[100] = {0};
unsigned char peerZ[32] = {0};
int i, ki, kn, ct;
gm_bn_t peerTmpX;
gm_point_t peerPubK, peerTmpPubK;
gm_sm3_context sm3_ctx;
gm_bn_from_bytes(peerTmpX, peer_r);
gm_point_from_bytes(&peerPubK, peer_p);
gm_point_from_bytes(&peerTmpPubK, peer_r);
// compute peer z digest
gm_sm2_compute_z_digest(id_bytes, idLen, &peerPubK, peerZ);
// 2^w + ( peerTmpX & ( 2^w − 1 ) )
gm_sm2_exch_reduce(peerTmpX);
// U = t * (peerPubK + peerTmpX · peerTmpPubK)
gm_point_mul(&peerTmpPubK, peerTmpX, &peerTmpPubK);
gm_point_add(&peerPubK, &peerPubK, &peerTmpPubK);
gm_point_mul(&peerPubK, ctx->t, &peerPubK);
gm_point_to_bytes(&peerPubK, buf);
// KDF(x_u || y_u || Z_A || Z_B)
kn = (kLen + 31) / 32;
ki = 0;
for(i = 0, ct = 1; i < kn; i++, ct++) {
gm_sm3_init(&sm3_ctx);
gm_sm3_update(&sm3_ctx, buf, 64);
if(ctx->isInitiator) {
gm_sm3_update(&sm3_ctx, ctx->z, 32);
gm_sm3_update(&sm3_ctx, peerZ, 32);
}else {
gm_sm3_update(&sm3_ctx, peerZ, 32);
gm_sm3_update(&sm3_ctx, ctx->z, 32);
}
GM_PUT_UINT32_BE(ct, buf + 64, 0);
gm_sm3_update(&sm3_ctx, buf + 64, 4);
gm_sm3_done(&sm3_ctx, buf + 68);
// output kA or kB
if(i == (kn - 1)) {
memcpy(output + ki, buf + 68, (kLen - ki));
}else {
memcpy(output + ki, buf + 68, 32);
ki += 32;
}
}
// Hash(0x02 || y_u || Hash(x_u || Z_A || Z_B || x_1 || y_1 || x_2 || y_2))
gm_sm3_init(&sm3_ctx);
gm_sm3_update(&sm3_ctx, buf, 32);
if(ctx->isInitiator) {
gm_sm3_update(&sm3_ctx, ctx->z, 32);
gm_sm3_update(&sm3_ctx, peerZ, 32);
gm_sm3_update(&sm3_ctx, ctx->xy, 64);
gm_sm3_update(&sm3_ctx, peer_r, 64);
}else {
gm_sm3_update(&sm3_ctx, peerZ, 32);
gm_sm3_update(&sm3_ctx, ctx->z, 32);
gm_sm3_update(&sm3_ctx, peer_r, 64);
gm_sm3_update(&sm3_ctx, ctx->xy, 64);
}
gm_sm3_done(&sm3_ctx, buf + 68);
gm_sm3_init(&sm3_ctx);
buf[31] = 0x02;
gm_sm3_update(&sm3_ctx, buf + 31, 33);
gm_sm3_update(&sm3_ctx, buf + 68, 32);
// ouput s1 or sB
gm_sm3_done(&sm3_ctx, output + kLen);
// Hash(0x03 || y_u || Hash(x_u || Z_A || Z_B || x_1 || y_1 || x_2 || y_2))
gm_sm3_init(&sm3_ctx);
buf[31] = 0x03;
gm_sm3_update(&sm3_ctx, buf + 31, 33);
gm_sm3_update(&sm3_ctx, buf + 68, 32);
// ouput s2 or sA
gm_sm3_done(&sm3_ctx, output + kLen + 32);
}
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# checkhash
hash在calculate阶段已经计算完并且输出到output了,由业务层进行校验即可。
# 单元测试
测试代码:
void test_sm2_key_exch() {
gm_bn_t da, tmpda;
gm_bn_t db, tmpdb;
gm_point_t pa, tmppa;
gm_point_t pb, tmppb;
unsigned char userId_a[3] = {0x61, 0x62, 0x63};
unsigned char userId_b[5] = {0x61, 0x62, 0x63, 0x64, 0x65};
unsigned char rp_a[128] = {0};
unsigned char rp_b[128] = {0};
unsigned char k_s1_sa[256] = {0};
unsigned char k_sb_s2[256] = {0};
gm_bn_from_hex(da, "1ED44070B763431D23D35A227A34D91558DC0B1EDD87E91238D4A54D98FAB6A0");
gm_bn_from_hex(tmpda, "B45B1F0577C6D37C86F252B394B20E55FEEEF2DEE49743A68EC7871CECD89872");
gm_bn_from_hex(db, "D18FE8EFD4E7C5B2FFDC356E16E397D2443DB6EA4C453EB5DC2852F8E301E846");
gm_bn_from_hex(tmpdb, "37ED4CE7C7951B76BE93CFD116A9F8AE439664107A59278E0F7095B964A8C7BA");
gm_point_mul(&pa, da, GM_MONT_G);
gm_point_mul(&tmppa, tmpda, GM_MONT_G);
gm_point_mul(&pb, db, GM_MONT_G);
gm_point_mul(&tmppb, tmpdb, GM_MONT_G);
gm_point_to_bytes(&pa, rp_a + 64);
gm_point_to_bytes(&pb, rp_b + 64);
gm_sm2_exch_context exa, exb;
unsigned char expbuf[116] = {0};
gm_hex2bin("3A18CB6BE2DC15C49998BE75DA28C4DEB3ADF33E08E886FCD7B2869CD006A6C4D5852D9E194A091EC9AC01B2D6B5153A09CA39BC3FB4984A09E4CE5B0DEC0E105CA12D712F6C8CBE59BFE54CAD0641B922D3EB0AD10C1D2347BA10985624ACC5A4C21400A3441D8EA5DE97B897B2635E6AEDE9",
230, expbuf);
int i;
for(i = 0; i < 100; i++) {
// A为发起方,发起方初始化
gm_sm2_exch_init_for_test(&exa, da, &pa, tmpda, &tmppa, 1, userId_a, 3, rp_a);
// B为响应方
gm_sm2_exch_init_for_test(&exb, db, &pb, tmpdb, &tmppb, 0, userId_b, 5, rp_b);
// B拿到A的r z w进行密钥计算
gm_sm2_exch_calculate(&exb, rp_a + 64, rp_a, userId_a, 3, 16 + i, k_sb_s2);
// A拿到B的r z w进行密钥计算
gm_sm2_exch_calculate(&exa, rp_b + 64, rp_b, userId_b, 5, 16 + i, k_s1_sa);
// A校验s1 == sb
if(memcmp(k_s1_sa + 16, k_sb_s2 + 16, 32) != 0) {
printf("test result s1 == sb: fail\n");
return;
}
// B校验s2 == sa
if(memcmp(k_s1_sa + 16 + 32, k_sb_s2 + 16 + 32, 32) != 0) {
printf("test result s2 == sa: fail\n");
return;
}
// 最后来看看两方计算的密钥值是否一致
if(memcmp(k_s1_sa, k_sb_s2, 16 + i) != 0) {
printf("test result ka == kb: fail\n");
return;
}
if(memcmp(k_s1_sa, expbuf, 16 + i) != 0) {
printf("test result check k: fail\n");
return;
}
}
printf("test result: ok\n");
}
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
main函数增加:
if(strcmp(argv[1], "sm2_key_exch") == 0) {
test_sm2_key_exch();
}
2
3
执行结果:
192:c saint$ time ./gm_test sm2_key_exch
test result: ok
real 0m2.206s
user 0m2.118s
sys 0m0.023s
2
3
4
5
6