diff --git a/qrexec/qrexec-client.c b/qrexec/qrexec-client.c index cfa7bd3..390b022 100644 --- a/qrexec/qrexec-client.c +++ b/qrexec/qrexec-client.c @@ -170,10 +170,14 @@ void do_exec(const char *prog) static void do_exit(int code) { int status; - // sever communication lines; wait for child, if any - // so that qrexec-daemon can count (recursively) spawned processes correctly + /* restore flags, as we may have not the only copy of this file descriptor + */ + if (local_stdin_fd != -1) + set_block(local_stdin_fd); close(local_stdin_fd); close(local_stdout_fd); + // sever communication lines; wait for child, if any + // so that qrexec-daemon can count (recursively) spawned processes correctly waitpid(-1, &status, 0); exit(code); } @@ -269,9 +273,15 @@ static void handle_input(libvchan_t *vchan) { char buf[MAX_DATA_CHUNK]; int ret; + size_t max_len; struct msg_header hdr; - ret = read(local_stdout_fd, buf, sizeof(buf)); + max_len = libvchan_buffer_space(vchan)-sizeof(hdr); + if (max_len > sizeof(buf)) + max_len = sizeof(buf); + if (max_len == 0) + return; + ret = read(local_stdout_fd, buf, max_len); if (ret < 0) { perror("read"); do_exit(1); @@ -328,12 +338,25 @@ void do_replace_esc(char *buf, int len) { buf[i] = '_'; } -static void handle_vchan_data(libvchan_t *vchan) +static int handle_vchan_data(libvchan_t *vchan, struct buffer *stdin_buf) { int status; struct msg_header hdr; char buf[MAX_DATA_CHUNK]; + if (local_stdin_fd != -1) { + switch(flush_client_data(local_stdin_fd, stdin_buf)) { + case WRITE_STDIN_ERROR: + perror("write stdin"); + close(local_stdin_fd); + local_stdin_fd = -1; + break; + case WRITE_STDIN_BUFFERED: + return WRITE_STDIN_BUFFERED; + case WRITE_STDIN_OK: + break; + } + } if (libvchan_recv(vchan, &hdr, sizeof hdr) < 0) { perror("read vchan"); do_exit(1); @@ -356,16 +379,29 @@ static void handle_vchan_data(libvchan_t *vchan) if (replace_esc_stdout) do_replace_esc(buf, hdr.len); if (hdr.len == 0) { + /* restore flags, as we may have not the only copy of this file descriptor + */ + if (local_stdin_fd != -1) + set_block(local_stdin_fd); close(local_stdin_fd); local_stdin_fd = -1; - } else if (!write_all(local_stdin_fd, buf, hdr.len)) { - if (errno == EPIPE) { - // remote side have closed its stdin, handle data in oposite - // direction (if any) before exit - local_stdin_fd = -1; - } else { - perror("write local stdout"); - do_exit(1); + } else { + switch (write_stdin(local_stdin_fd, buf, hdr.len, stdin_buf)) { + case WRITE_STDIN_BUFFERED: + return WRITE_STDIN_BUFFERED; + case WRITE_STDIN_ERROR: + if (errno == EPIPE) { + // local process have closed its stdin, handle data in oposite + // direction (if any) before exit + close(local_stdin_fd); + local_stdin_fd = -1; + } else { + perror("write local stdout"); + do_exit(1); + } + break; + case WRITE_STDIN_OK: + break; } } break; @@ -381,12 +417,17 @@ static void handle_vchan_data(libvchan_t *vchan) else memcpy(&status, buf, sizeof(status)); + flush_client_data(local_stdin_fd, stdin_buf); do_exit(status); break; default: fprintf(stderr, "unknown msg %d\n", hdr.type); do_exit(1); } + /* intentionally do not distinguish between _ERROR and _OK, because in case + * of write error, we simply eat the data - no way to report it to the + * other side */ + return WRITE_STDIN_OK; } static void check_child_status(libvchan_t *vchan) @@ -409,36 +450,52 @@ static void check_child_status(libvchan_t *vchan) static void select_loop(libvchan_t *vchan) { fd_set select_set; + fd_set wr_set; int max_fd; int ret; int vchan_fd; sigset_t selectmask; struct timespec zero_timeout = { 0, 0 }; struct timespec select_timeout = { 10, 0 }; + struct buffer stdin_buf; sigemptyset(&selectmask); sigaddset(&selectmask, SIGCHLD); sigprocmask(SIG_BLOCK, &selectmask, NULL); sigemptyset(&selectmask); + buffer_init(&stdin_buf); + /* remember to set back to blocking mode before closing the FD - this may + * be not the only copy and some processes may misbehave when get + * nonblocking FD for input/output + */ + set_nonblock(local_stdin_fd); for (;;) { vchan_fd = libvchan_fd_for_select(vchan); FD_ZERO(&select_set); + FD_ZERO(&wr_set); FD_SET(vchan_fd, &select_set); max_fd = vchan_fd; - if (local_stdout_fd != -1 && libvchan_buffer_space(vchan)) { + if (local_stdout_fd != -1 && + (size_t)libvchan_buffer_space(vchan) > sizeof(struct msg_header)) { FD_SET(local_stdout_fd, &select_set); if (local_stdout_fd > max_fd) max_fd = local_stdout_fd; } if (child_exited && local_stdout_fd == -1) check_child_status(vchan); - if (libvchan_data_ready(vchan) > 0) { + if (local_stdin_fd != -1 && buffer_len(&stdin_buf)) { + FD_SET(local_stdin_fd, &wr_set); + if (local_stdin_fd > max_fd) + max_fd = local_stdin_fd; + } + if ((local_stdin_fd == -1 || buffer_len(&stdin_buf) == 0) && + libvchan_data_ready(vchan) > 0) { /* check for other FDs, but exit immediately */ - ret = pselect(max_fd + 1, &select_set, NULL, NULL, + ret = pselect(max_fd + 1, &select_set, &wr_set, NULL, &zero_timeout, &selectmask); } else - ret = pselect(max_fd + 1, &select_set, NULL, NULL, + ret = pselect(max_fd + 1, &select_set, &wr_set, NULL, &select_timeout, &selectmask); if (ret < 0) { if (errno == EINTR && local_pid > 0) { @@ -456,8 +513,18 @@ static void select_loop(libvchan_t *vchan) } if (FD_ISSET(vchan_fd, &select_set)) libvchan_wait(vchan); + if (buffer_len(&stdin_buf) && + local_stdin_fd != -1 && + FD_ISSET(local_stdin_fd, &wr_set)) { + if (flush_client_data(local_stdin_fd, &stdin_buf) == WRITE_STDIN_ERROR) { + perror("write stdin"); + close(local_stdin_fd); + local_stdin_fd = -1; + } + } while (libvchan_data_ready(vchan)) - handle_vchan_data(vchan); + if (handle_vchan_data(vchan, &stdin_buf) != WRITE_STDIN_OK) + break; if (local_stdout_fd != -1 && FD_ISSET(local_stdout_fd, &select_set))