diff --git a/mtproto-client.c b/mtproto-client.c index d4521e5..3932233 100644 --- a/mtproto-client.c +++ b/mtproto-client.c @@ -416,6 +416,77 @@ int process_respq_answer (struct connection *c, char *packet, int len) { return rpc_send_packet (c); } +int check_DH_params (BIGNUM *p, int g) { + if (g < 2 || g > 7) { return -1; } + BIGNUM t; + BN_init (&t); + + BN_init (&dh_g); + BN_set_word (&dh_g, 4 * g); + + BN_mod (&t, p, &dh_g, BN_ctx); + int x = BN_get_word (&t); + assert (x >= 0 && x < 4 * g); + + BN_clear (&dh_g); + + switch (g) { + case 2: + if (x != 7) { return -1; } + break; + case 3: + if (x % 3 != 2 ) { return -1; } + break; + case 4: + break; + case 5: + if (x % 5 != 1 && x % 5 != 4) { return -1; } + break; + case 6: + if (x != 19 && x != 23) { return -1; } + break; + case 7: + if (x % 7 != 3 && x % 7 != 5 && x % 7 != 6) { return -1; } + break; + } + + if (!BN_is_prime (p, BN_prime_checks, 0, BN_ctx, 0)) { return -1; } + + BIGNUM b; + BN_init (&b); + BN_set_word (&b, 2); + BN_div (&t, 0, p, &b, BN_ctx); + if (!BN_is_prime (&t, BN_prime_checks, 0, BN_ctx, 0)) { return -1; } + BN_clear (&b); + BN_clear (&t); + return 0; +} + +int check_g (BIGNUM *g) { + static unsigned char s[256]; + memset (s, 0, 256); + assert (BN_num_bytes (g) <= 256); + BN_bn2bin (g, s); + int ok = 0; + int i; + for (i = 0; i < 64; i++) { + if (s[i]) { + ok = 1; + break; + } + } + if (!ok) { return -1; } + ok = 0; + for (i = 0; i < 64; i++) { + if (s[255 - i]) { + ok = 1; + break; + } + } + if (!ok) { return -1; } + return 0; +} + int process_dh_answer (struct connection *c, char *packet, int len) { if (verbosity) { logprintf ( "process_dh_answer(), len=%d\n", len); @@ -448,9 +519,12 @@ int process_dh_answer (struct connection *c, char *packet, int len) { BN_init (&g_a); assert (fetch_bignum (&dh_prime) > 0); assert (fetch_bignum (&g_a) > 0); + assert (check_g (&g_a) >= 0); int server_time = *in_ptr++; assert (in_ptr <= in_end); + assert (check_DH_params (&dh_prime, g) >= 0); + static char sha1_buffer[20]; sha1 ((unsigned char *) decrypt_buffer + 20, (in_ptr - decrypt_buffer - 5) * 4, (unsigned char *) sha1_buffer); assert (!memcmp (decrypt_buffer, sha1_buffer, 20));