密钥协商由密钥系统相同的两方进行,一方为发起方A(initiator),另一方为响应方B(responder)。

A的私钥为dA,公钥为PA(xA, yA),

ZA的计算方法,在前面的章节中已经给出了计算方法gm_sm2_compute_z_digest,可以去回顾一下。

同样的,B的私钥为dB,公钥为PB(xB, yB)以及ZB。

# 协商过程

密钥协商的过程如下:

用户A:
A1:生成一对随机SM2密钥对,私钥为,公钥为
A2:计算
A3:计算 ,将发送给B

用户B:
B1:生成一对随机SM2密钥对,私钥为,公钥为
B2:计算
B3:计算

B4:计算
B5:计算椭圆曲线点
B6:计算
B7:将RA的坐标x1、y1和RB的坐标x2、y2的数据类型转换为比特串,计算
B8:将发送给A

用户A:
A4:计算
A5:计算椭圆曲线点
A6:计算
A7:将RA的坐标x1、y1和RB的坐标x2、y2的数据类型转换为比特串,计算,并检验是否成立,若等式不成立则密钥协商失败
A8:计算 A9:并将SA发送给用户B

用户B:
B9:计算,并检验是否成立,若等式不成立则密钥协商失败

# 分阶段实现

接下来分阶段来实现,把协商过程分为三个阶段init、calculate、checkhash。

init: A1-A3,B1-B3,初始化阶段,主要是生成临时密钥对,计算过程数据,然后发送给对方
calculate:B4-B7, A4-A8接收到init过程数据后,计算出协商密钥,哈希(杂凑值)。
checkhash:校验哈希是否一致,一致则协商成功,不一致则协商失败。

# init

初始化阶段,密钥对生成之前章节中已经有相应的实现了,拿过来用就行。

来看一下A2的计算,

这里的w = 127,那么2^w就是:

// 00000000 00000000 00000000 00000000 80000000 00000000 00000000 00000000
static const gm_bn_t GM_BN_2W = {
    0x00000000, 0x00000000, 0x00000000, 0x80000000,
    0x00000000, 0x00000000, 0x00000000, 0x00000000
};
1
2
3
4
5

2^w - 1就是:

// 00000000 00000000 00000000 00000000 7FFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF
static const gm_bn_t GM_BN_2W_SUB_ONE = {
    0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x7FFFFFFF,
    0x00000000, 0x00000000, 0x00000000, 0x00000000
};
1
2
3
4
5

有了预计算的值,完成的实现就简单了

// 2^w + ( x & ( 2^w − 1 ) )
static void gm_sm2_exch_reduce(gm_bn_t x) {
    int i;
    int num = GM_BN_ARR_SIZE / 2;
    for(i = 0; i < GM_BN_ARR_SIZE; i++) {
        if(i < num) {
            x[i] &= GM_BN_2W_SUB_ONE[i];
            x[i] += GM_BN_2W[i];
        }else {
            x[i] = 0;
        }
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13

这里为何能写的这么简单,不考虑加法溢出吗?

大家可以看一下,x先与GM_BN_2W_SUB_ONE相与,那么计算后,结果的最大值就是GM_BN_2W_SUB_ONE,那GM_BN_2W_SUB_ONE与GM_BN_2W相加,是不可能有溢出的。

我们来看一下完整的代码实现:

typedef struct {
    gm_bn_t t;                      // tA or tB
    unsigned char xy[64];           // 临时公钥(x1, y1) or (x2, y2)
    unsigned char z[32];            // ZA or ZB
    unsigned char isInitiator;      // 1为发起方,否则为响应方
} gm_sm2_exch_context;

/**
 * 密钥协商初始化
 * @param ctx 上下文
 * @param private_key 用户私钥rA or rB
 * @param public_key 用户公司RA or RB
 * @param isInitiator 1为发起方,否则为响应方
 * @param output 输出 RA or RB
 */
void gm_sm2_exch_init(gm_sm2_exch_context * ctx, gm_bn_t private_key, const gm_point_t * public_key, 
  unsigned char isInitiator, const unsigned char * id_bytes, unsigned int idLen, unsigned char output[64]) {
    gm_bn_t r;
    gm_point_t pr;

    // 生成临时密钥对
    gm_sm2_gen_keypair(r, &pr);

    gm_sm2_exch_init_for_test(ctx, private_key, public_key, r, &pr, isInitiator, id_bytes, idLen, output);
}

void gm_sm2_exch_init_for_test(gm_sm2_exch_context * ctx, gm_bn_t private_key, const gm_point_t * public_key, 
  gm_bn_t tmp_private_key, const gm_point_t * tmp_public_key, 
  unsigned char isInitiator, const unsigned char * id_bytes, unsigned int idLen, unsigned char output[64]) {
    gm_bn_t r, x, y;
    gm_point_t pr;

    gm_bn_copy(r, tmp_private_key);
    gm_point_copy(&pr, tmp_public_key);

    gm_point_get_xy(&pr, x, y);

    // 2^w + ( x & ( 2^w − 1 ) )
    gm_sm2_exch_reduce(x);

    // t = (d + x · r) mod n
    gm_bn_to_mont(x, x, GM_BN_N);
    gm_bn_to_mont(r, r, GM_BN_N);
    // x * r
    gm_bn_mont_mul(r, r, x, GM_BN_N);
    // d + x * r
    gm_bn_from_mont(r, r, GM_BN_N);
    gm_bn_add(ctx->t, r, private_key, GM_BN_N);

    // compute z digest
    gm_sm2_compute_z_digest(id_bytes, idLen, public_key, ctx->z);

    ctx->isInitiator = isInitiator;
    gm_point_to_bytes(&pr, ctx->xy);

    // output R
    memcpy(output, ctx->xy, 64);
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58

# calculate

计算密钥K,用于校验的Hash值

/**
 * 计算密钥K,S1/SB、S2/SA
 * @param ctx 上下文
 * @param peer_p 对方公钥P
 * @param peer_r 对方初始化信息 R,即随机公钥
 * @param id_bytes 对方user id
 * @param idLen 对方user id 长度
 * @param kLen 协商的密钥长度(单位字节)
 * @param output 输出密钥 k || S1/SB || S2/SA,长度为kLen + 64
 */
void gm_sm2_exch_calculate(gm_sm2_exch_context * ctx, const unsigned char * peer_p, const unsigned char * peer_r, 
  const unsigned char * id_bytes, unsigned int idLen, int kLen, unsigned char * output) {
    unsigned char buf[100] = {0};
    unsigned char peerZ[32] = {0};
    int i, ki, kn, ct;
    
    gm_bn_t peerTmpX;
    gm_point_t peerPubK, peerTmpPubK;
    gm_sm3_context sm3_ctx;

    gm_bn_from_bytes(peerTmpX, peer_r);
    gm_point_from_bytes(&peerPubK, peer_p);
    gm_point_from_bytes(&peerTmpPubK, peer_r);

    // compute peer z digest
    gm_sm2_compute_z_digest(id_bytes, idLen, &peerPubK, peerZ);

    // 2^w + ( peerTmpX & ( 2^w − 1 ) )
    gm_sm2_exch_reduce(peerTmpX);

    // U = t * (peerPubK + peerTmpX · peerTmpPubK)
    gm_point_mul(&peerTmpPubK, peerTmpX, &peerTmpPubK);
    gm_point_add(&peerPubK, &peerPubK, &peerTmpPubK);

    
    gm_point_mul(&peerPubK, ctx->t, &peerPubK);
    gm_point_to_bytes(&peerPubK, buf);

    // KDF(x_u || y_u || Z_A || Z_B)
    kn = (kLen + 31) / 32;
    ki = 0;
    for(i = 0, ct = 1; i < kn; i++, ct++) {
        
        gm_sm3_init(&sm3_ctx);
        gm_sm3_update(&sm3_ctx, buf, 64);

        if(ctx->isInitiator) {
            gm_sm3_update(&sm3_ctx, ctx->z, 32);
            gm_sm3_update(&sm3_ctx, peerZ, 32);
        }else {
            gm_sm3_update(&sm3_ctx, peerZ, 32);
            gm_sm3_update(&sm3_ctx, ctx->z, 32);
        }

        GM_PUT_UINT32_BE(ct, buf + 64, 0);
        gm_sm3_update(&sm3_ctx, buf + 64, 4);
        gm_sm3_done(&sm3_ctx, buf + 68);

        // output kA or kB
        if(i == (kn - 1)) {
            memcpy(output + ki, buf + 68, (kLen - ki));
        }else {
            memcpy(output + ki, buf + 68, 32);
            ki += 32;
        }
    }

    // Hash(0x02 || y_u || Hash(x_u || Z_A || Z_B || x_1 || y_1 || x_2 || y_2))
    gm_sm3_init(&sm3_ctx);
    gm_sm3_update(&sm3_ctx, buf, 32);
    if(ctx->isInitiator) {
        gm_sm3_update(&sm3_ctx, ctx->z, 32);
        gm_sm3_update(&sm3_ctx, peerZ, 32);
        gm_sm3_update(&sm3_ctx, ctx->xy, 64);
        gm_sm3_update(&sm3_ctx, peer_r, 64);
    }else {
        gm_sm3_update(&sm3_ctx, peerZ, 32);
        gm_sm3_update(&sm3_ctx, ctx->z, 32);
        gm_sm3_update(&sm3_ctx, peer_r, 64);
        gm_sm3_update(&sm3_ctx, ctx->xy, 64);
    }
    gm_sm3_done(&sm3_ctx, buf + 68);

    gm_sm3_init(&sm3_ctx);
    buf[31] = 0x02;
    gm_sm3_update(&sm3_ctx, buf + 31, 33);
    gm_sm3_update(&sm3_ctx, buf + 68, 32);

    // ouput s1 or sB
    gm_sm3_done(&sm3_ctx, output + kLen);

    // Hash(0x03 || y_u || Hash(x_u || Z_A || Z_B || x_1 || y_1 || x_2 || y_2))
    gm_sm3_init(&sm3_ctx);
    buf[31] = 0x03;
    gm_sm3_update(&sm3_ctx, buf + 31, 33);
    gm_sm3_update(&sm3_ctx, buf + 68, 32);

    // ouput s2 or sA
    gm_sm3_done(&sm3_ctx, output + kLen + 32);
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100

# checkhash

hash在calculate阶段已经计算完并且输出到output了,由业务层进行校验即可。

# 单元测试

测试代码:

void test_sm2_key_exch() {
    gm_bn_t da, tmpda;
    gm_bn_t db, tmpdb;
    gm_point_t pa, tmppa;
    gm_point_t pb, tmppb;

    unsigned char userId_a[3] = {0x61, 0x62, 0x63};
    unsigned char userId_b[5] = {0x61, 0x62, 0x63, 0x64, 0x65};

    unsigned char rp_a[128] = {0};
    unsigned char rp_b[128] = {0};

    unsigned char k_s1_sa[256] = {0};
    unsigned char k_sb_s2[256] = {0};

    gm_bn_from_hex(da, "1ED44070B763431D23D35A227A34D91558DC0B1EDD87E91238D4A54D98FAB6A0");
    gm_bn_from_hex(tmpda, "B45B1F0577C6D37C86F252B394B20E55FEEEF2DEE49743A68EC7871CECD89872");

    gm_bn_from_hex(db, "D18FE8EFD4E7C5B2FFDC356E16E397D2443DB6EA4C453EB5DC2852F8E301E846");
    gm_bn_from_hex(tmpdb, "37ED4CE7C7951B76BE93CFD116A9F8AE439664107A59278E0F7095B964A8C7BA");

    gm_point_mul(&pa, da, GM_MONT_G);
    gm_point_mul(&tmppa, tmpda, GM_MONT_G);

    gm_point_mul(&pb, db, GM_MONT_G);
    gm_point_mul(&tmppb, tmpdb, GM_MONT_G);

    gm_point_to_bytes(&pa, rp_a + 64);
    gm_point_to_bytes(&pb, rp_b + 64);

    gm_sm2_exch_context exa, exb;

    unsigned char expbuf[116] = {0};
    gm_hex2bin("3A18CB6BE2DC15C49998BE75DA28C4DEB3ADF33E08E886FCD7B2869CD006A6C4D5852D9E194A091EC9AC01B2D6B5153A09CA39BC3FB4984A09E4CE5B0DEC0E105CA12D712F6C8CBE59BFE54CAD0641B922D3EB0AD10C1D2347BA10985624ACC5A4C21400A3441D8EA5DE97B897B2635E6AEDE9", 
        230, expbuf);

    int i;

    for(i = 0; i < 100; i++) {
        // A为发起方,发起方初始化
        gm_sm2_exch_init_for_test(&exa, da, &pa, tmpda, &tmppa, 1, userId_a, 3, rp_a);

        // B为响应方
        gm_sm2_exch_init_for_test(&exb, db, &pb, tmpdb, &tmppb, 0, userId_b, 5, rp_b);

        // B拿到A的r z w进行密钥计算
        gm_sm2_exch_calculate(&exb, rp_a + 64, rp_a, userId_a, 3, 16 + i, k_sb_s2);

        // A拿到B的r z w进行密钥计算
        gm_sm2_exch_calculate(&exa, rp_b + 64, rp_b, userId_b, 5, 16 + i, k_s1_sa);
        // A校验s1 == sb
        if(memcmp(k_s1_sa + 16, k_sb_s2 + 16, 32) != 0) {
            printf("test result s1 == sb: fail\n");
            return;
        }

        // B校验s2 == sa
        if(memcmp(k_s1_sa + 16 + 32, k_sb_s2 + 16 + 32, 32) != 0) {
            printf("test result s2 == sa: fail\n");
            return;
        }

        // 最后来看看两方计算的密钥值是否一致
        if(memcmp(k_s1_sa, k_sb_s2, 16 + i) != 0) {
            printf("test result ka == kb: fail\n");
            return;
        }

        if(memcmp(k_s1_sa, expbuf, 16 + i) != 0) {
            printf("test result check k: fail\n");
            return;
        }
    }
    printf("test result: ok\n");
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75

main函数增加:

if(strcmp(argv[1], "sm2_key_exch") == 0) {
    test_sm2_key_exch();
}
1
2
3

执行结果:

192:c saint$ time ./gm_test sm2_key_exch
test result: ok

real	0m2.206s
user	0m2.118s
sys	0m0.023s
1
2
3
4
5
6