diff options
Diffstat (limited to 'ssh-agent.c')
-rw-r--r-- | ssh-agent.c | 744 |
1 files changed, 276 insertions, 468 deletions
diff --git a/ssh-agent.c b/ssh-agent.c index b987562b9aa1..0c6c3659217f 100644 --- a/ssh-agent.c +++ b/ssh-agent.c @@ -1,4 +1,4 @@ -/* $OpenBSD: ssh-agent.c,v 1.218 2017/03/15 03:52:30 deraadt Exp $ */ +/* $OpenBSD: ssh-agent.c,v 1.224 2017/07/24 04:34:28 djm Exp $ */ /* * Author: Tatu Ylonen <ylo@cs.hut.fi> * Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo, Finland @@ -60,6 +60,9 @@ #ifdef HAVE_PATHS_H # include <paths.h> #endif +#ifdef HAVE_POLL_H +# include <poll.h> +#endif #include <signal.h> #include <stdarg.h> #include <stdio.h> @@ -73,7 +76,6 @@ #include "xmalloc.h" #include "ssh.h" -#include "rsa.h" #include "sshbuf.h" #include "sshkey.h" #include "authfd.h" @@ -92,6 +94,9 @@ # define DEFAULT_PKCS11_WHITELIST "/usr/lib*/*,/usr/local/lib*/*" #endif +/* Maximum accepted message length */ +#define AGENT_MAX_LEN (256*1024) + typedef enum { AUTH_UNUSED, AUTH_SOCKET, @@ -118,13 +123,13 @@ typedef struct identity { u_int confirm; } Identity; -typedef struct { +struct idtable { int nentries; TAILQ_HEAD(idqueue, identity) idlist; -} Idtab; +}; -/* private key table, one per protocol version */ -Idtab idtable[3]; +/* private key table */ +struct idtable *idtab; int max_fd = 0; @@ -171,21 +176,9 @@ close_socket(SocketEntry *e) static void idtab_init(void) { - int i; - - for (i = 0; i <=2; i++) { - TAILQ_INIT(&idtable[i].idlist); - idtable[i].nentries = 0; - } -} - -/* return private key table for requested protocol version */ -static Idtab * -idtab_lookup(int version) -{ - if (version < 1 || version > 2) - fatal("internal error, bad protocol version %d", version); - return &idtable[version]; + idtab = xcalloc(1, sizeof(*idtab)); + TAILQ_INIT(&idtab->idlist); + idtab->nentries = 0; } static void @@ -199,12 +192,11 @@ free_identity(Identity *id) /* return matching private key for given public key */ static Identity * -lookup_identity(struct sshkey *key, int version) +lookup_identity(struct sshkey *key) { Identity *id; - Idtab *tab = idtab_lookup(version); - TAILQ_FOREACH(id, &tab->idlist, next) { + TAILQ_FOREACH(id, &idtab->idlist, next) { if (sshkey_equal(key, id->key)) return (id); } @@ -241,135 +233,30 @@ send_status(SocketEntry *e, int success) /* send list of supported public keys to 'client' */ static void -process_request_identities(SocketEntry *e, int version) +process_request_identities(SocketEntry *e) { - Idtab *tab = idtab_lookup(version); Identity *id; struct sshbuf *msg; int r; if ((msg = sshbuf_new()) == NULL) fatal("%s: sshbuf_new failed", __func__); - if ((r = sshbuf_put_u8(msg, (version == 1) ? - SSH_AGENT_RSA_IDENTITIES_ANSWER : - SSH2_AGENT_IDENTITIES_ANSWER)) != 0 || - (r = sshbuf_put_u32(msg, tab->nentries)) != 0) - fatal("%s: buffer error: %s", __func__, ssh_err(r)); - TAILQ_FOREACH(id, &tab->idlist, next) { - if (id->key->type == KEY_RSA1) { -#ifdef WITH_SSH1 - if ((r = sshbuf_put_u32(msg, - BN_num_bits(id->key->rsa->n))) != 0 || - (r = sshbuf_put_bignum1(msg, - id->key->rsa->e)) != 0 || - (r = sshbuf_put_bignum1(msg, - id->key->rsa->n)) != 0) - fatal("%s: buffer error: %s", - __func__, ssh_err(r)); -#endif - } else { - u_char *blob; - size_t blen; - - if ((r = sshkey_to_blob(id->key, &blob, &blen)) != 0) { - error("%s: sshkey_to_blob: %s", __func__, - ssh_err(r)); - continue; - } - if ((r = sshbuf_put_string(msg, blob, blen)) != 0) - fatal("%s: buffer error: %s", - __func__, ssh_err(r)); - free(blob); - } - if ((r = sshbuf_put_cstring(msg, id->comment)) != 0) - fatal("%s: buffer error: %s", __func__, ssh_err(r)); - } - if ((r = sshbuf_put_stringb(e->output, msg)) != 0) - fatal("%s: buffer error: %s", __func__, ssh_err(r)); - sshbuf_free(msg); -} - -#ifdef WITH_SSH1 -/* ssh1 only */ -static void -process_authentication_challenge1(SocketEntry *e) -{ - u_char buf[32], mdbuf[16], session_id[16]; - u_int response_type; - BIGNUM *challenge; - Identity *id; - int r, len; - struct sshbuf *msg; - struct ssh_digest_ctx *md; - struct sshkey *key; - - if ((msg = sshbuf_new()) == NULL) - fatal("%s: sshbuf_new failed", __func__); - if ((key = sshkey_new(KEY_RSA1)) == NULL) - fatal("%s: sshkey_new failed", __func__); - if ((challenge = BN_new()) == NULL) - fatal("%s: BN_new failed", __func__); - - if ((r = sshbuf_get_u32(e->request, NULL)) != 0 || /* ignored */ - (r = sshbuf_get_bignum1(e->request, key->rsa->e)) != 0 || - (r = sshbuf_get_bignum1(e->request, key->rsa->n)) != 0 || - (r = sshbuf_get_bignum1(e->request, challenge))) - fatal("%s: buffer error: %s", __func__, ssh_err(r)); - - /* Only protocol 1.1 is supported */ - if (sshbuf_len(e->request) == 0) - goto failure; - if ((r = sshbuf_get(e->request, session_id, sizeof(session_id))) != 0 || - (r = sshbuf_get_u32(e->request, &response_type)) != 0) + if ((r = sshbuf_put_u8(msg, SSH2_AGENT_IDENTITIES_ANSWER)) != 0 || + (r = sshbuf_put_u32(msg, idtab->nentries)) != 0) fatal("%s: buffer error: %s", __func__, ssh_err(r)); - if (response_type != 1) - goto failure; - - id = lookup_identity(key, 1); - if (id != NULL && (!id->confirm || confirm_key(id) == 0)) { - struct sshkey *private = id->key; - /* Decrypt the challenge using the private key. */ - if ((r = rsa_private_decrypt(challenge, challenge, - private->rsa) != 0)) { - fatal("%s: rsa_public_encrypt: %s", __func__, + TAILQ_FOREACH(id, &idtab->idlist, next) { + if ((r = sshkey_puts(id->key, msg)) != 0 || + (r = sshbuf_put_cstring(msg, id->comment)) != 0) { + error("%s: put key/comment: %s", __func__, ssh_err(r)); - goto failure; /* XXX ? */ + continue; } - - /* The response is MD5 of decrypted challenge plus session id */ - len = BN_num_bytes(challenge); - if (len <= 0 || len > 32) { - logit("%s: bad challenge length %d", __func__, len); - goto failure; - } - memset(buf, 0, 32); - BN_bn2bin(challenge, buf + 32 - len); - if ((md = ssh_digest_start(SSH_DIGEST_MD5)) == NULL || - ssh_digest_update(md, buf, 32) < 0 || - ssh_digest_update(md, session_id, 16) < 0 || - ssh_digest_final(md, mdbuf, sizeof(mdbuf)) < 0) - fatal("%s: md5 failed", __func__); - ssh_digest_free(md); - - /* Send the response. */ - if ((r = sshbuf_put_u8(msg, SSH_AGENT_RSA_RESPONSE)) != 0 || - (r = sshbuf_put(msg, mdbuf, sizeof(mdbuf))) != 0) - fatal("%s: buffer error: %s", __func__, ssh_err(r)); - goto send; } - - failure: - /* Unknown identity or protocol error. Send failure. */ - if ((r = sshbuf_put_u8(msg, SSH_AGENT_FAILURE)) != 0) - fatal("%s: buffer error: %s", __func__, ssh_err(r)); - send: if ((r = sshbuf_put_stringb(e->output, msg)) != 0) fatal("%s: buffer error: %s", __func__, ssh_err(r)); - sshkey_free(key); - BN_clear_free(challenge); sshbuf_free(msg); } -#endif + static char * agent_decode_alg(struct sshkey *key, u_int flags) @@ -387,27 +274,24 @@ agent_decode_alg(struct sshkey *key, u_int flags) static void process_sign_request2(SocketEntry *e) { - u_char *blob, *data, *signature = NULL; - size_t blen, dlen, slen = 0; + const u_char *data; + u_char *signature = NULL; + size_t dlen, slen = 0; u_int compat = 0, flags; int r, ok = -1; struct sshbuf *msg; - struct sshkey *key; + struct sshkey *key = NULL; struct identity *id; if ((msg = sshbuf_new()) == NULL) fatal("%s: sshbuf_new failed", __func__); - if ((r = sshbuf_get_string(e->request, &blob, &blen)) != 0 || - (r = sshbuf_get_string(e->request, &data, &dlen)) != 0 || + if ((r = sshkey_froms(e->request, &key)) != 0 || + (r = sshbuf_get_string_direct(e->request, &data, &dlen)) != 0 || (r = sshbuf_get_u32(e->request, &flags)) != 0) fatal("%s: buffer error: %s", __func__, ssh_err(r)); if (flags & SSH_AGENT_OLD_SIGNATURE) compat = SSH_BUG_SIGBLOB; - if ((r = sshkey_from_blob(blob, blen, &key)) != 0) { - error("%s: cannot parse key blob: %s", __func__, ssh_err(r)); - goto send; - } - if ((id = lookup_identity(key, 2)) == NULL) { + if ((id = lookup_identity(key)) == NULL) { verbose("%s: %s key not found", __func__, sshkey_type(key)); goto send; } @@ -435,90 +319,52 @@ process_sign_request2(SocketEntry *e) fatal("%s: buffer error: %s", __func__, ssh_err(r)); sshbuf_free(msg); - free(data); - free(blob); free(signature); } /* shared */ static void -process_remove_identity(SocketEntry *e, int version) +process_remove_identity(SocketEntry *e) { - size_t blen; int r, success = 0; struct sshkey *key = NULL; - u_char *blob; -#ifdef WITH_SSH1 - u_int bits; -#endif /* WITH_SSH1 */ - - switch (version) { -#ifdef WITH_SSH1 - case 1: - if ((key = sshkey_new(KEY_RSA1)) == NULL) { - error("%s: sshkey_new failed", __func__); - return; - } - if ((r = sshbuf_get_u32(e->request, &bits)) != 0 || - (r = sshbuf_get_bignum1(e->request, key->rsa->e)) != 0 || - (r = sshbuf_get_bignum1(e->request, key->rsa->n)) != 0) - fatal("%s: buffer error: %s", __func__, ssh_err(r)); + Identity *id; - if (bits != sshkey_size(key)) - logit("Warning: identity keysize mismatch: " - "actual %u, announced %u", - sshkey_size(key), bits); - break; -#endif /* WITH_SSH1 */ - case 2: - if ((r = sshbuf_get_string(e->request, &blob, &blen)) != 0) - fatal("%s: buffer error: %s", __func__, ssh_err(r)); - if ((r = sshkey_from_blob(blob, blen, &key)) != 0) - error("%s: sshkey_from_blob failed: %s", - __func__, ssh_err(r)); - free(blob); - break; + if ((r = sshkey_froms(e->request, &key)) != 0) { + error("%s: get key: %s", __func__, ssh_err(r)); + goto done; } - if (key != NULL) { - Identity *id = lookup_identity(key, version); - if (id != NULL) { - /* - * We have this key. Free the old key. Since we - * don't want to leave empty slots in the middle of - * the array, we actually free the key there and move - * all the entries between the empty slot and the end - * of the array. - */ - Idtab *tab = idtab_lookup(version); - if (tab->nentries < 1) - fatal("process_remove_identity: " - "internal error: tab->nentries %d", - tab->nentries); - TAILQ_REMOVE(&tab->idlist, id, next); - free_identity(id); - tab->nentries--; - success = 1; - } - sshkey_free(key); + if ((id = lookup_identity(key)) == NULL) { + debug("%s: key not found", __func__); + goto done; } + /* We have this key, free it. */ + if (idtab->nentries < 1) + fatal("%s: internal error: nentries %d", + __func__, idtab->nentries); + TAILQ_REMOVE(&idtab->idlist, id, next); + free_identity(id); + idtab->nentries--; + sshkey_free(key); + success = 1; + done: send_status(e, success); } static void -process_remove_all_identities(SocketEntry *e, int version) +process_remove_all_identities(SocketEntry *e) { - Idtab *tab = idtab_lookup(version); Identity *id; /* Loop over all identities and clear the keys. */ - for (id = TAILQ_FIRST(&tab->idlist); id; - id = TAILQ_FIRST(&tab->idlist)) { - TAILQ_REMOVE(&tab->idlist, id, next); + for (id = TAILQ_FIRST(&idtab->idlist); id; + id = TAILQ_FIRST(&idtab->idlist)) { + TAILQ_REMOVE(&idtab->idlist, id, next); free_identity(id); } /* Mark that there are no identities. */ - tab->nentries = 0; + idtab->nentries = 0; /* Send success. */ send_status(e, 1); @@ -530,24 +376,19 @@ reaper(void) { time_t deadline = 0, now = monotime(); Identity *id, *nxt; - int version; - Idtab *tab; - - for (version = 1; version < 3; version++) { - tab = idtab_lookup(version); - for (id = TAILQ_FIRST(&tab->idlist); id; id = nxt) { - nxt = TAILQ_NEXT(id, next); - if (id->death == 0) - continue; - if (now >= id->death) { - debug("expiring key '%s'", id->comment); - TAILQ_REMOVE(&tab->idlist, id, next); - free_identity(id); - tab->nentries--; - } else - deadline = (deadline == 0) ? id->death : - MINIMUM(deadline, id->death); - } + + for (id = TAILQ_FIRST(&idtab->idlist); id; id = nxt) { + nxt = TAILQ_NEXT(id, next); + if (id->death == 0) + continue; + if (now >= id->death) { + debug("expiring key '%s'", id->comment); + TAILQ_REMOVE(&idtab->idlist, id, next); + free_identity(id); + idtab->nentries--; + } else + deadline = (deadline == 0) ? id->death : + MINIMUM(deadline, id->death); } if (deadline == 0 || deadline <= now) return 0; @@ -555,54 +396,9 @@ reaper(void) return (deadline - now); } -/* - * XXX this and the corresponding serialisation function probably belongs - * in key.c - */ -#ifdef WITH_SSH1 -static int -agent_decode_rsa1(struct sshbuf *m, struct sshkey **kp) -{ - struct sshkey *k = NULL; - int r = SSH_ERR_INTERNAL_ERROR; - - *kp = NULL; - if ((k = sshkey_new_private(KEY_RSA1)) == NULL) - return SSH_ERR_ALLOC_FAIL; - - if ((r = sshbuf_get_u32(m, NULL)) != 0 || /* ignored */ - (r = sshbuf_get_bignum1(m, k->rsa->n)) != 0 || - (r = sshbuf_get_bignum1(m, k->rsa->e)) != 0 || - (r = sshbuf_get_bignum1(m, k->rsa->d)) != 0 || - (r = sshbuf_get_bignum1(m, k->rsa->iqmp)) != 0 || - /* SSH1 and SSL have p and q swapped */ - (r = sshbuf_get_bignum1(m, k->rsa->q)) != 0 || /* p */ - (r = sshbuf_get_bignum1(m, k->rsa->p)) != 0) /* q */ - goto out; - - /* Generate additional parameters */ - if ((r = rsa_generate_additional_parameters(k->rsa)) != 0) - goto out; - /* enable blinding */ - if (RSA_blinding_on(k->rsa, NULL) != 1) { - r = SSH_ERR_LIBCRYPTO_ERROR; - goto out; - } - - r = 0; /* success */ - out: - if (r == 0) - *kp = k; - else - sshkey_free(k); - return r; -} -#endif /* WITH_SSH1 */ - static void -process_add_identity(SocketEntry *e, int version) +process_add_identity(SocketEntry *e) { - Idtab *tab = idtab_lookup(version); Identity *id; int success = 0, confirm = 0; u_int seconds; @@ -612,17 +408,8 @@ process_add_identity(SocketEntry *e, int version) u_char ctype; int r = SSH_ERR_INTERNAL_ERROR; - switch (version) { -#ifdef WITH_SSH1 - case 1: - r = agent_decode_rsa1(e->request, &k); - break; -#endif /* WITH_SSH1 */ - case 2: - r = sshkey_private_deserialize(e->request, &k); - break; - } - if (r != 0 || k == NULL || + if ((r = sshkey_private_deserialize(e->request, &k)) != 0 || + k == NULL || (r = sshbuf_get_cstring(e->request, &comment, NULL)) != 0) { error("%s: decode private key: %s", __func__, ssh_err(r)); goto err; @@ -658,12 +445,12 @@ process_add_identity(SocketEntry *e, int version) success = 1; if (lifetime && !death) death = monotime() + lifetime; - if ((id = lookup_identity(k, version)) == NULL) { + if ((id = lookup_identity(k)) == NULL) { id = xcalloc(1, sizeof(Identity)); id->key = k; - TAILQ_INSERT_TAIL(&tab->idlist, id, next); + TAILQ_INSERT_TAIL(&idtab->idlist, id, next); /* Increment the number of identities. */ - tab->nentries++; + idtab->nentries++; } else { sshkey_free(k); free(id->comment); @@ -724,17 +511,14 @@ process_lock_agent(SocketEntry *e, int lock) } static void -no_identities(SocketEntry *e, u_int type) +no_identities(SocketEntry *e) { struct sshbuf *msg; int r; if ((msg = sshbuf_new()) == NULL) fatal("%s: sshbuf_new failed", __func__); - if ((r = sshbuf_put_u8(msg, - (type == SSH_AGENTC_REQUEST_RSA_IDENTITIES) ? - SSH_AGENT_RSA_IDENTITIES_ANSWER : - SSH2_AGENT_IDENTITIES_ANSWER)) != 0 || + if ((r = sshbuf_put_u8(msg, SSH2_AGENT_IDENTITIES_ANSWER)) != 0 || (r = sshbuf_put_u32(msg, 0)) != 0 || (r = sshbuf_put_stringb(e->output, msg)) != 0) fatal("%s: buffer error: %s", __func__, ssh_err(r)); @@ -746,13 +530,12 @@ static void process_add_smartcard_key(SocketEntry *e) { char *provider = NULL, *pin, canonical_provider[PATH_MAX]; - int r, i, version, count = 0, success = 0, confirm = 0; + int r, i, count = 0, success = 0, confirm = 0; u_int seconds; time_t death = 0; u_char type; struct sshkey **keys = NULL, *k; Identity *id; - Idtab *tab; if ((r = sshbuf_get_cstring(e->request, &provider, NULL)) != 0 || (r = sshbuf_get_cstring(e->request, &pin, NULL)) != 0) @@ -772,8 +555,7 @@ process_add_smartcard_key(SocketEntry *e) confirm = 1; break; default: - error("process_add_smartcard_key: " - "Unknown constraint type %d", type); + error("%s: Unknown constraint type %d", __func__, type); goto send; } } @@ -794,17 +576,15 @@ process_add_smartcard_key(SocketEntry *e) count = pkcs11_add_provider(canonical_provider, pin, &keys); for (i = 0; i < count; i++) { k = keys[i]; - version = k->type == KEY_RSA1 ? 1 : 2; - tab = idtab_lookup(version); - if (lookup_identity(k, version) == NULL) { + if (lookup_identity(k) == NULL) { id = xcalloc(1, sizeof(Identity)); id->key = k; id->provider = xstrdup(canonical_provider); id->comment = xstrdup(canonical_provider); /* XXX */ id->death = death; id->confirm = confirm; - TAILQ_INSERT_TAIL(&tab->idlist, id, next); - tab->nentries++; + TAILQ_INSERT_TAIL(&idtab->idlist, id, next); + idtab->nentries++; success = 1; } else { sshkey_free(k); @@ -822,9 +602,8 @@ static void process_remove_smartcard_key(SocketEntry *e) { char *provider = NULL, *pin = NULL, canonical_provider[PATH_MAX]; - int r, version, success = 0; + int r, success = 0; Identity *id, *nxt; - Idtab *tab; if ((r = sshbuf_get_cstring(e->request, &provider, NULL)) != 0 || (r = sshbuf_get_cstring(e->request, &pin, NULL)) != 0) @@ -838,25 +617,21 @@ process_remove_smartcard_key(SocketEntry *e) } debug("%s: remove %.100s", __func__, canonical_provider); - for (version = 1; version < 3; version++) { - tab = idtab_lookup(version); - for (id = TAILQ_FIRST(&tab->idlist); id; id = nxt) { - nxt = TAILQ_NEXT(id, next); - /* Skip file--based keys */ - if (id->provider == NULL) - continue; - if (!strcmp(canonical_provider, id->provider)) { - TAILQ_REMOVE(&tab->idlist, id, next); - free_identity(id); - tab->nentries--; - } + for (id = TAILQ_FIRST(&idtab->idlist); id; id = nxt) { + nxt = TAILQ_NEXT(id, next); + /* Skip file--based keys */ + if (id->provider == NULL) + continue; + if (!strcmp(canonical_provider, id->provider)) { + TAILQ_REMOVE(&idtab->idlist, id, next); + free_identity(id); + idtab->nentries--; } } if (pkcs11_del_provider(canonical_provider) == 0) success = 1; else - error("process_remove_smartcard_key:" - " pkcs11_del_provider failed"); + error("%s: pkcs11_del_provider failed", __func__); send: free(provider); send_status(e, success); @@ -865,88 +640,86 @@ send: /* dispatch incoming messages */ -static void -process_message(SocketEntry *e) +static int +process_message(u_int socknum) { u_int msg_len; u_char type; const u_char *cp; int r; + SocketEntry *e; + + if (socknum >= sockets_alloc) { + fatal("%s: socket number %u >= allocated %u", + __func__, socknum, sockets_alloc); + } + e = &sockets[socknum]; if (sshbuf_len(e->input) < 5) - return; /* Incomplete message. */ + return 0; /* Incomplete message header. */ cp = sshbuf_ptr(e->input); msg_len = PEEK_U32(cp); - if (msg_len > 256 * 1024) { - close_socket(e); - return; + if (msg_len > AGENT_MAX_LEN) { + debug("%s: socket %u (fd=%d) message too long %u > %u", + __func__, socknum, e->fd, msg_len, AGENT_MAX_LEN); + return -1; } if (sshbuf_len(e->input) < msg_len + 4) - return; + return 0; /* Incomplete message body. */ /* move the current input to e->request */ sshbuf_reset(e->request); if ((r = sshbuf_get_stringb(e->input, e->request)) != 0 || - (r = sshbuf_get_u8(e->request, &type)) != 0) + (r = sshbuf_get_u8(e->request, &type)) != 0) { + if (r == SSH_ERR_MESSAGE_INCOMPLETE || + r == SSH_ERR_STRING_TOO_LARGE) { + debug("%s: buffer error: %s", __func__, ssh_err(r)); + return -1; + } fatal("%s: buffer error: %s", __func__, ssh_err(r)); + } + + debug("%s: socket %u (fd=%d) type %d", __func__, socknum, e->fd, type); /* check wheter agent is locked */ if (locked && type != SSH_AGENTC_UNLOCK) { sshbuf_reset(e->request); switch (type) { - case SSH_AGENTC_REQUEST_RSA_IDENTITIES: case SSH2_AGENTC_REQUEST_IDENTITIES: /* send empty lists */ - no_identities(e, type); + no_identities(e); break; default: /* send a fail message for all other request types */ send_status(e, 0); } - return; + return 0; } - debug("type %d", type); switch (type) { case SSH_AGENTC_LOCK: case SSH_AGENTC_UNLOCK: process_lock_agent(e, type == SSH_AGENTC_LOCK); break; -#ifdef WITH_SSH1 - /* ssh1 */ - case SSH_AGENTC_RSA_CHALLENGE: - process_authentication_challenge1(e); - break; - case SSH_AGENTC_REQUEST_RSA_IDENTITIES: - process_request_identities(e, 1); - break; - case SSH_AGENTC_ADD_RSA_IDENTITY: - case SSH_AGENTC_ADD_RSA_ID_CONSTRAINED: - process_add_identity(e, 1); - break; - case SSH_AGENTC_REMOVE_RSA_IDENTITY: - process_remove_identity(e, 1); - break; -#endif case SSH_AGENTC_REMOVE_ALL_RSA_IDENTITIES: - process_remove_all_identities(e, 1); /* safe for !WITH_SSH1 */ + process_remove_all_identities(e); /* safe for !WITH_SSH1 */ break; /* ssh2 */ case SSH2_AGENTC_SIGN_REQUEST: process_sign_request2(e); break; case SSH2_AGENTC_REQUEST_IDENTITIES: - process_request_identities(e, 2); + process_request_identities(e); break; case SSH2_AGENTC_ADD_IDENTITY: case SSH2_AGENTC_ADD_ID_CONSTRAINED: - process_add_identity(e, 2); + process_add_identity(e); break; case SSH2_AGENTC_REMOVE_IDENTITY: - process_remove_identity(e, 2); + process_remove_identity(e); break; case SSH2_AGENTC_REMOVE_ALL_IDENTITIES: - process_remove_all_identities(e, 2); + process_remove_all_identities(e); break; #ifdef ENABLE_PKCS11 case SSH_AGENTC_ADD_SMARTCARD_KEY: @@ -964,6 +737,7 @@ process_message(SocketEntry *e) send_status(e, 0); break; } + return 0; } static void @@ -1005,19 +779,141 @@ new_socket(sock_type type, int fd) } static int -prepare_select(fd_set **fdrp, fd_set **fdwp, int *fdl, u_int *nallocp, - struct timeval **tvpp) +handle_socket_read(u_int socknum) +{ + struct sockaddr_un sunaddr; + socklen_t slen; + uid_t euid; + gid_t egid; + int fd; + + slen = sizeof(sunaddr); + fd = accept(sockets[socknum].fd, (struct sockaddr *)&sunaddr, &slen); + if (fd < 0) { + error("accept from AUTH_SOCKET: %s", strerror(errno)); + return -1; + } + if (getpeereid(fd, &euid, &egid) < 0) { + error("getpeereid %d failed: %s", fd, strerror(errno)); + close(fd); + return -1; + } + if ((euid != 0) && (getuid() != euid)) { + error("uid mismatch: peer euid %u != uid %u", + (u_int) euid, (u_int) getuid()); + close(fd); + return -1; + } + new_socket(AUTH_CONNECTION, fd); + return 0; +} + +static int +handle_conn_read(u_int socknum) +{ + char buf[1024]; + ssize_t len; + int r; + + if ((len = read(sockets[socknum].fd, buf, sizeof(buf))) <= 0) { + if (len == -1) { + if (errno == EAGAIN || errno == EINTR) + return 0; + error("%s: read error on socket %u (fd %d): %s", + __func__, socknum, sockets[socknum].fd, + strerror(errno)); + } + return -1; + } + if ((r = sshbuf_put(sockets[socknum].input, buf, len)) != 0) + fatal("%s: buffer error: %s", __func__, ssh_err(r)); + explicit_bzero(buf, sizeof(buf)); + process_message(socknum); + return 0; +} + +static int +handle_conn_write(u_int socknum) { - u_int i, sz; - int n = 0; - static struct timeval tv; + ssize_t len; + int r; + + if (sshbuf_len(sockets[socknum].output) == 0) + return 0; /* shouldn't happen */ + if ((len = write(sockets[socknum].fd, + sshbuf_ptr(sockets[socknum].output), + sshbuf_len(sockets[socknum].output))) <= 0) { + if (len == -1) { + if (errno == EAGAIN || errno == EINTR) + return 0; + error("%s: read error on socket %u (fd %d): %s", + __func__, socknum, sockets[socknum].fd, + strerror(errno)); + } + return -1; + } + if ((r = sshbuf_consume(sockets[socknum].output, len)) != 0) + fatal("%s: buffer error: %s", __func__, ssh_err(r)); + return 0; +} + +static void +after_poll(struct pollfd *pfd, size_t npfd) +{ + size_t i; + u_int socknum; + + for (i = 0; i < npfd; i++) { + if (pfd[i].revents == 0) + continue; + /* Find sockets entry */ + for (socknum = 0; socknum < sockets_alloc; socknum++) { + if (sockets[socknum].type != AUTH_SOCKET && + sockets[socknum].type != AUTH_CONNECTION) + continue; + if (pfd[i].fd == sockets[socknum].fd) + break; + } + if (socknum >= sockets_alloc) { + error("%s: no socket for fd %d", __func__, pfd[i].fd); + continue; + } + /* Process events */ + switch (sockets[socknum].type) { + case AUTH_SOCKET: + if ((pfd[i].revents & (POLLIN|POLLERR)) != 0 && + handle_socket_read(socknum) != 0) + close_socket(&sockets[socknum]); + break; + case AUTH_CONNECTION: + if ((pfd[i].revents & (POLLIN|POLLERR)) != 0 && + handle_conn_read(socknum) != 0) { + close_socket(&sockets[socknum]); + break; + } + if ((pfd[i].revents & (POLLOUT|POLLHUP)) != 0 && + handle_conn_write(socknum) != 0) + close_socket(&sockets[socknum]); + break; + default: + break; + } + } +} + +static int +prepare_poll(struct pollfd **pfdp, size_t *npfdp, int *timeoutp) +{ + struct pollfd *pfd = *pfdp; + size_t i, j, npfd = 0; time_t deadline; + /* Count active sockets */ for (i = 0; i < sockets_alloc; i++) { switch (sockets[i].type) { case AUTH_SOCKET: case AUTH_CONNECTION: - n = MAXIMUM(n, sockets[i].fd); + npfd++; break; case AUTH_UNUSED: break; @@ -1026,28 +922,23 @@ prepare_select(fd_set **fdrp, fd_set **fdwp, int *fdl, u_int *nallocp, break; } } + if (npfd != *npfdp && + (pfd = recallocarray(pfd, *npfdp, npfd, sizeof(*pfd))) == NULL) + fatal("%s: recallocarray failed", __func__); + *pfdp = pfd; + *npfdp = npfd; - sz = howmany(n+1, NFDBITS) * sizeof(fd_mask); - if (*fdrp == NULL || sz > *nallocp) { - free(*fdrp); - free(*fdwp); - *fdrp = xmalloc(sz); - *fdwp = xmalloc(sz); - *nallocp = sz; - } - if (n < *fdl) - debug("XXX shrink: %d < %d", n, *fdl); - *fdl = n; - memset(*fdrp, 0, sz); - memset(*fdwp, 0, sz); - - for (i = 0; i < sockets_alloc; i++) { + for (i = j = 0; i < sockets_alloc; i++) { switch (sockets[i].type) { case AUTH_SOCKET: case AUTH_CONNECTION: - FD_SET(sockets[i].fd, *fdrp); + pfd[j].fd = sockets[i].fd; + pfd[j].revents = 0; + /* XXX backoff when input buffer full */ + pfd[j].events = POLLIN; if (sshbuf_len(sockets[i].output) > 0) - FD_SET(sockets[i].fd, *fdwp); + pfd[j].events |= POLLOUT; + j++; break; default: break; @@ -1058,99 +949,17 @@ prepare_select(fd_set **fdrp, fd_set **fdwp, int *fdl, u_int *nallocp, deadline = (deadline == 0) ? parent_alive_interval : MINIMUM(deadline, parent_alive_interval); if (deadline == 0) { - *tvpp = NULL; + *timeoutp = -1; /* INFTIM */ } else { - tv.tv_sec = deadline; - tv.tv_usec = 0; - *tvpp = &tv; + if (deadline > INT_MAX / 1000) + *timeoutp = INT_MAX / 1000; + else + *timeoutp = deadline * 1000; } return (1); } static void -after_select(fd_set *readset, fd_set *writeset) -{ - struct sockaddr_un sunaddr; - socklen_t slen; - char buf[1024]; - int len, sock, r; - u_int i, orig_alloc; - uid_t euid; - gid_t egid; - - for (i = 0, orig_alloc = sockets_alloc; i < orig_alloc; i++) - switch (sockets[i].type) { - case AUTH_UNUSED: - break; - case AUTH_SOCKET: - if (FD_ISSET(sockets[i].fd, readset)) { - slen = sizeof(sunaddr); - sock = accept(sockets[i].fd, - (struct sockaddr *)&sunaddr, &slen); - if (sock < 0) { - error("accept from AUTH_SOCKET: %s", - strerror(errno)); - break; - } - if (getpeereid(sock, &euid, &egid) < 0) { - error("getpeereid %d failed: %s", - sock, strerror(errno)); - close(sock); - break; - } - if ((euid != 0) && (getuid() != euid)) { - error("uid mismatch: " - "peer euid %u != uid %u", - (u_int) euid, (u_int) getuid()); - close(sock); - break; - } - new_socket(AUTH_CONNECTION, sock); - } - break; - case AUTH_CONNECTION: - if (sshbuf_len(sockets[i].output) > 0 && - FD_ISSET(sockets[i].fd, writeset)) { - len = write(sockets[i].fd, - sshbuf_ptr(sockets[i].output), - sshbuf_len(sockets[i].output)); - if (len == -1 && (errno == EAGAIN || - errno == EWOULDBLOCK || - errno == EINTR)) - continue; - if (len <= 0) { - close_socket(&sockets[i]); - break; - } - if ((r = sshbuf_consume(sockets[i].output, - len)) != 0) - fatal("%s: buffer error: %s", - __func__, ssh_err(r)); - } - if (FD_ISSET(sockets[i].fd, readset)) { - len = read(sockets[i].fd, buf, sizeof(buf)); - if (len == -1 && (errno == EAGAIN || - errno == EWOULDBLOCK || - errno == EINTR)) - continue; - if (len <= 0) { - close_socket(&sockets[i]); - break; - } - if ((r = sshbuf_put(sockets[i].input, - buf, len)) != 0) - fatal("%s: buffer error: %s", - __func__, ssh_err(r)); - explicit_bzero(buf, sizeof(buf)); - process_message(&sockets[i]); - } - break; - default: - fatal("Unknown type %d", sockets[i].type); - } -} - -static void cleanup_socket(void) { if (cleanup_pid != 0 && getpid() != cleanup_pid) @@ -1209,9 +1018,7 @@ main(int ac, char **av) { int c_flag = 0, d_flag = 0, D_flag = 0, k_flag = 0, s_flag = 0; int sock, fd, ch, result, saved_errno; - u_int nalloc; char *shell, *format, *pidstr, *agentsocket = NULL; - fd_set *readsetp = NULL, *writesetp = NULL; #ifdef HAVE_SETRLIMIT struct rlimit rlim; #endif @@ -1219,9 +1026,11 @@ main(int ac, char **av) extern char *optarg; pid_t pid; char pidstrbuf[1 + 3 * sizeof pid]; - struct timeval *tvp = NULL; size_t len; mode_t prev_mask; + int timeout = -1; /* INFTIM */ + struct pollfd *pfd = NULL; + size_t npfd = 0; ssh_malloc_init(); /* must be called before any mallocs */ /* Ensure that fds 0, 1 and 2 are open or directed to /dev/null */ @@ -1442,15 +1251,14 @@ skip: signal(SIGINT, (d_flag | D_flag) ? cleanup_handler : SIG_IGN); signal(SIGHUP, cleanup_handler); signal(SIGTERM, cleanup_handler); - nalloc = 0; if (pledge("stdio rpath cpath unix id proc exec", NULL) == -1) fatal("%s: pledge: %s", __progname, strerror(errno)); platform_pledge_agent(); while (1) { - prepare_select(&readsetp, &writesetp, &max_fd, &nalloc, &tvp); - result = select(max_fd + 1, readsetp, writesetp, NULL, tvp); + prepare_poll(&pfd, &npfd, &timeout); + result = poll(pfd, npfd, timeout); saved_errno = errno; if (parent_alive_interval != 0) check_parent_exists(); @@ -1458,9 +1266,9 @@ skip: if (result < 0) { if (saved_errno == EINTR) continue; - fatal("select: %s", strerror(saved_errno)); + fatal("poll: %s", strerror(saved_errno)); } else if (result > 0) - after_select(readsetp, writesetp); + after_poll(pfd, npfd); } /* NOTREACHED */ } |