/*
 * client.c  (C) SOFNEC
 * libwebsock client interface.
 */
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <getopt.h>
#include <string.h>
#include <poll.h>

#include "client.h"

/* Maximum number of contexts tracked in ContextAssocList. */
#define MAX_CONTEXT_NUM         10
/* Receive buffer size requested per protocol. Parenthesized so that
   expressions such as WS_BUFFER_SIZE * 2 expand correctly. */
#define WS_BUFFER_SIZE          (1 << 16)

/* One entry of the context table: a libwebsocket context paired with the
   protocol array that ws_create_context() allocated for it. */
typedef struct {
  struct libwebsocket_context   *context;    /* NULL marks a free slot */
  struct libwebsocket_protocols *protocols;  /* freed by ws_delete_context() */
} CONTEXT_ASSOC;

static CONTEXT_ASSOC ContextAssocList[MAX_CONTEXT_NUM];

static PER_SESSION_DATA* LastConnectedSession;

static unsigned int Options = 0;

static char* ClientCertificateFilePath;
static char* ClientPrivateKeyFilePath;


/*
 * Forwards `level_masks` to lws_set_log_level().
 * NOTE(review): NULL is passed as the emitter; presumably this keeps the
 * library's current log emitter — confirm against the libwebsockets docs.
 */
extern void
ws_set_log_level(int level_masks)
{
  void* no_emitter = NULL;

  lws_set_log_level(level_masks, no_emitter);
}


/* Returns a context-table entry to the free state (context == NULL). */
static void
context_assoc_clear(CONTEXT_ASSOC* ca)
{
  static const CONTEXT_ASSOC free_slot = { NULL, NULL };

  *ca = free_slot;
}

/*
 * Disabled helper (compiled out with #if 0): would reset every
 * ContextAssocList slot to the free state via context_assoc_clear().
 */
#if 0
static void
context_assoc_list_clear(void)
{
  int i;

  for (i = 0; i < MAX_CONTEXT_NUM; i++) {
    context_assoc_clear(&ContextAssocList[i]);
  }
}
#endif

/*
 * Records (context, protocols) in the first free slot of ContextAssocList.
 * Returns the slot used, or NULL when all MAX_CONTEXT_NUM slots are taken.
 */
static CONTEXT_ASSOC*
context_assoc_add(struct libwebsocket_context* context,
                  struct libwebsocket_protocols* protocols)
{
  CONTEXT_ASSOC* slot = ContextAssocList;
  CONTEXT_ASSOC* const end = ContextAssocList + MAX_CONTEXT_NUM;

  while (slot < end && slot->context != NULL)
    slot++;

  if (slot == end)
    return NULL;

  slot->context   = context;
  slot->protocols = protocols;
  return slot;
}

/*
 * Returns the ContextAssocList entry for `context`, or NULL if it is not
 * registered. A NULL context is rejected explicitly: NULL marks free slots,
 * so matching it would return an unused entry whose protocols are NULL.
 */
static CONTEXT_ASSOC*
context_assoc_find(struct libwebsocket_context* context)
{
  int i;

  if (context == NULL)
    return NULL;

  for (i = 0; i < MAX_CONTEXT_NUM; i++) {
    if (ContextAssocList[i].context == context)
      return &ContextAssocList[i];
  }

  return NULL;
}


/*
 * Stores a private heap copy of `path` in *slot, releasing the old copy.
 *
 * - A NULL or empty `path` clears *slot.
 * - A path equal to the current value is left untouched.
 * - The new copy is made before the old one is freed, so `path` may
 *   safely point into the current *slot.
 *
 * On strdup() failure the old value is released, *slot ends up NULL and
 * -1 is returned. This is the same end state the setters always had.
 * Returns 0 otherwise.
 */
static int
replace_owned_path(char** slot, const char* path)
{
  char* copy;

  if (path == NULL || *path == '\0') {
    free(*slot);
    *slot = NULL;
    return 0;
  }

  if (*slot != NULL && strcmp(path, *slot) == 0)
    return 0;

  copy = strdup(path);
  free(*slot);
  *slot = copy;
  return (copy == NULL) ? -1 : 0;
}

/*
 * Sets the PEM client certificate file that the
 * LWS_CALLBACK_OPENSSL_LOAD_EXTRA_CLIENT_VERIFY_CERTS handler loads into the
 * SSL_CTX. It is only used when a private key path is also set.
 * NULL or "" clears it. Returns 0 on success, -1 if the copy fails.
 */
extern int
ws_set_client_certificate_file_path(char* path)
{
  return replace_owned_path(&ClientCertificateFilePath, path);
}

/*
 * Sets the PEM private key file that the
 * LWS_CALLBACK_OPENSSL_LOAD_EXTRA_CLIENT_VERIFY_CERTS handler loads into the
 * SSL_CTX. It is only used when a certificate path is also set.
 * NULL or "" clears it. Returns 0 on success, -1 if the copy fails.
 */
extern int
ws_set_client_private_key_file_path(char* path)
{
  return replace_owned_path(&ClientPrivateKeyFilePath, path);
}

static int
callback_simple_request_response(struct libwebsocket_context *this,
                                 struct libwebsocket *wsi,
                                 enum libwebsocket_callback_reasons reason,
                                 void *user, void *in, size_t len)
{
//#define CA_DIR          "/var/cert/ca/"

  PER_SESSION_DATA* ps = user;

  //ws_print_reason(reason);

  switch (reason) {
  case LWS_CALLBACK_OPENSSL_LOAD_EXTRA_CLIENT_VERIFY_CERTS: 
    {
      int r;
      SSL_CTX* ctx = (SSL_CTX* )user;

#if 0
      fprintf(stderr,
              "LWS_CALLBACK_OPENSSL_LOAD_EXTRA_CLIENT_VERIFY_CERTS: 0x%x, len:%d\r\n",
              (unsigned int )ctx, len);
#endif

#if 0
      if (!SSL_CTX_load_verify_locations(ctx, NULL, CA_DIR))
        fprintf(stderr, "ERROR: load_verify_locations.\n");

      if (SSL_CTX_set_default_verify_paths(ctx) != 1)
        fprintf(stderr, "ERROR: loading default CA file and /or directory.\n");
#endif

      if (ClientCertificateFilePath != NULL && ClientPrivateKeyFilePath != NULL) {
        /* fprintf(stderr, "setting Client certificate fie path.\n"); */

        r = SSL_CTX_use_certificate_file(ctx, ClientCertificateFilePath,
                                         SSL_FILETYPE_PEM);
        if (r != 1)
          fprintf(stderr, "ERROR: loading certificate from file.\n");

        r = SSL_CTX_use_PrivateKey_file(ctx, ClientPrivateKeyFilePath,
                                        SSL_FILETYPE_PEM);
        if (r != 1)
          fprintf(stderr, "ERROR: loading private key from file.\n");
      }
    }
    break;

  case LWS_CALLBACK_CLOSED:
    ps->wsi    = 0;
    ps->opened = FALSE;
    ws_clear_ringbuffer(ps, TRUE);
    break;

  case LWS_CALLBACK_CLIENT_ESTABLISHED:
    libwebsocket_callback_on_writable(this, wsi);
    ps->wsi     = wsi;
    ps->context = this;
    ps->opened  = TRUE;
    ws_clear_ringbuffer(ps, FALSE);
    LastConnectedSession = ps;
    break;

  case LWS_CALLBACK_CLIENT_RECEIVE:
#ifdef DEBUG
    ((char* )in)[len] = '\0';
    fprintf(stderr, "rx: '%s', %d\n", (char* )in, len);
#endif
    if (len == 0)
      break;

    if (ps->ringbuffer[ps->ringbuffer_head].payload)
      free(ps->ringbuffer[ps->ringbuffer_head].payload);

    ps->ringbuffer[ps->ringbuffer_head].payload =
      malloc(LWS_SEND_BUFFER_PRE_PADDING + len +
             LWS_SEND_BUFFER_POST_PADDING);
    ps->ringbuffer[ps->ringbuffer_head].len = len;
    memcpy((char *)ps->ringbuffer[ps->ringbuffer_head].payload +
           LWS_SEND_BUFFER_PRE_PADDING, in, len);
    if (ps->ringbuffer_head == (MAX_MESSAGE_QUEUE - 1))
      ps->ringbuffer_head = 0;
    else
      ps->ringbuffer_head++;

    if (((ps->ringbuffer_head - ps->ringbuffer_tail) %
         MAX_MESSAGE_QUEUE) > (MAX_MESSAGE_QUEUE - 10))
      libwebsocket_rx_flow_control(wsi, 0);

    libwebsocket_callback_on_writable_all_protocol(libwebsockets_get_protocol(wsi));
    break;

  case LWS_CALLBACK_CLIENT_WRITEABLE:
    /* get notified as soon as we can write again */
    libwebsocket_callback_on_writable(this, wsi);

	/*
	 * without at least this delay, we choke the browser
	 * and the connection stalls, despite we now take care about
	 * flow control
	 */

    //usleep(200);
    break;

  default:
    break;
  }

  return 0;
}


/*
 * poll()s the socket behind `wsi` for readability for up to `timeout` ms
 * (negative blocks indefinitely, as with poll()). `pfd` must point to at
 * least one struct pollfd; it is filled in and receives the result.
 *
 * Returns poll()'s result: >0 ready, 0 timed out, <0 error.
 * Returns -1 if the wsi has no valid socket.
 */
static int
wsi_poll(int timeout, struct libwebsocket* wsi, struct pollfd* pfd)
{
  int sock = libwebsocket_get_socket_fd(wsi);

  if (sock < 0)
    return -1;

  pfd->fd      = sock;
  pfd->events  = POLLIN | POLLERR;
  pfd->revents = 0;

  return poll(pfd, 1, timeout);
}

/*
 * Services `context` until LWS_CALLBACK_CLIENT_ESTABLISHED publishes a
 * session through LastConnectedSession.
 *
 * timeout > 0:  give up after about `timeout` ms, measured from entry.
 *               Each service call waits at most for the time left.
 * timeout <= 0: wait indefinitely. Each libwebsocket_service() call
 *               receives `timeout` unchanged.
 *
 * Returns 0 once a session is connected. Returns -1 on timeout or if
 * libwebsocket_service() reports an error. `wsi` is currently unused.
 */
static int
connect_wait(int timeout, struct libwebsocket_context* context,
             struct libwebsocket* wsi)
{
  struct timeval tstart;
  struct timeval tnow;
  int wait_ms = timeout;

  (void)wsi;
  ws_time(&tstart);

  while (LastConnectedSession == NULL) {
    if (libwebsocket_service(context, wait_ms) < 0)
      return -1;

    if (LastConnectedSession != NULL)
      break;

    if (timeout > 0) {
      /* Elapsed time is always measured from the start. The old code
         subtracted the cumulative elapsed time on every pass, which made
         the wait expire far too early. */
      ws_time(&tnow);
      wait_ms = timeout - ws_time_diff(&tstart, &tnow);
      if (wait_ms <= 0)
        return -1;
    }

    usleep(10000); /* 10 msec. back-off between service passes */
  }

  return 0;
}

/*
 * Waits until the ring buffer of `ps` holds at least one message.
 *
 * timeout < 0 waits indefinitely; otherwise waits about `timeout` ms.
 *
 * The socket is first poll()ed for readability. The context is then
 * serviced until a message is queued, the session closes, or the time
 * runs out. Once poll() reports data, the context is always serviced at
 * least once (non-blocking if no time is left). This means a timeout of
 * 0 still picks up data that is already waiting; previously it never did.
 *
 * Returns:
 *    1  a message is queued
 *    0  timed out
 *   -1  the session is (or becomes) closed with nothing queued
 *   -2  poll() failed or the wsi has no socket
 *   -3  libwebsocket_service() failed
 */
static int
read_wait(int timeout, PER_SESSION_DATA* ps)
{
  struct pollfd pfd;
  struct timeval tstart;
  struct timeval tnow;
  int n;

  if (ps->ringbuffer_tail != ps->ringbuffer_head)
    return 1;

  if (ps->opened == FALSE)
    return -1;

  ws_time(&tstart);
  n = wsi_poll(timeout, ps->wsi, &pfd);
  if (n < 0)   /* error */
    return -2;
  if (n == 0)  /* timeout */
    return 0;

  for (;;) {
    int wait_ms = timeout;

    if (timeout >= 0) {
      ws_time(&tnow);
      wait_ms = timeout - ws_time_diff(&tstart, &tnow);
      if (wait_ms < 0)
        wait_ms = 0;
    }

    if (libwebsocket_service(ps->context, wait_ms) != 0)
      return -3;

    if (ps->ringbuffer_tail != ps->ringbuffer_head)
      return 1;
    if (ps->opened == FALSE)
      return -1;
    /* the final pass ran non-blocking: time is up */
    if (timeout >= 0 && wait_ms == 0)
      return 0;
  }
}

/*
 * Waits up to `timeout` ms (negative: forever) for a message, then copies
 * the oldest queued message into `buf` as a NUL-terminated string and
 * removes it from the queue.
 *
 * At most buf_size - 1 bytes are copied. A longer message is truncated,
 * and the rest is discarded because the message is consumed either way.
 * Previously `buf_size` was ignored and long messages overflowed `buf`.
 *
 * Returns:
 *   >0  number of bytes copied
 *    0  timed out
 *   <0  the error code from the internal wait
 *   -4  `buf` is NULL or buf_size <= 0
 */
extern int
ws_read(int timeout, PER_SESSION_DATA* ps, char* buf, int buf_size)
{
  int r;
  int len;
  int slot;

  if (buf == NULL || buf_size <= 0)
    return -4;

  r = read_wait(timeout, ps);
  if (r < 0)
    return r;

  if (ps->ringbuffer_tail == ps->ringbuffer_head)
    return 0;

  slot = ps->ringbuffer_tail;
  len  = (int )ps->ringbuffer[slot].len;
  if (len > buf_size - 1)
    len = buf_size - 1;

  memcpy(buf,
         (char* )ps->ringbuffer[slot].payload + LWS_SEND_BUFFER_PRE_PADDING,
         len);
  buf[len] = '\0';

  ps->ringbuffer_tail = (slot == MAX_MESSAGE_QUEUE - 1) ? 0 : slot + 1;

  /* Re-enable reception once the queue has drained. The fill level is
     biased by the queue size because head - tail is negative after head
     wraps. Skip this on a closed session, whose wsi has been cleared. */
  if (ps->opened &&
      ((ps->ringbuffer_head - ps->ringbuffer_tail + MAX_MESSAGE_QUEUE) %
       MAX_MESSAGE_QUEUE) < (MAX_MESSAGE_QUEUE - 15))
    libwebsocket_rx_flow_control(ps->wsi, 1);

  return len;
}

/*
 * Waits up to `timeout` ms for a message. The message is NOT consumed:
 * *pp is pointed at the payload of the oldest queued message, which stays
 * in place. Call ws_read_finish() to release it.
 *
 * Returns the payload length, 0 if nothing is queued (*pp untouched), or
 * the negative error code from the internal wait.
 */
extern int
ws_read_ready(int timeout, PER_SESSION_DATA* ps, void** pp)
{
  int status = read_wait(timeout, ps);

  if (status < 0)
    return status;

  if (ps->ringbuffer_tail == ps->ringbuffer_head)
    return 0;

  *pp = (char* )ps->ringbuffer[ps->ringbuffer_tail].payload +
        LWS_SEND_BUFFER_PRE_PADDING;
  return ps->ringbuffer[ps->ringbuffer_tail].len;
}

/*
 * Releases the message previously exposed by ws_read_ready() by advancing
 * the ring-buffer tail. Re-enables rx flow control once the queue has
 * drained enough, but only while the session is open.
 *
 * Always returns 0.
 */
extern int
ws_read_finish(PER_SESSION_DATA* ps)
{
  if (ps->ringbuffer_tail == ps->ringbuffer_head)
    return 0;

  ps->ringbuffer_tail = (ps->ringbuffer_tail == MAX_MESSAGE_QUEUE - 1) ?
    0 : ps->ringbuffer_tail + 1;

  /* head - tail is negative once head has wrapped. Bias by the queue size
     so the modulo yields the real fill level. */
  if (ps->opened &&
      ((ps->ringbuffer_head - ps->ringbuffer_tail + MAX_MESSAGE_QUEUE) %
       MAX_MESSAGE_QUEUE) < (MAX_MESSAGE_QUEUE - 15)) {
    libwebsocket_rx_flow_control(ps->wsi, 1);
  }

  return 0;
}


/*
 * Shared implementation of ws_write_text(), ws_write_binary() and
 * ws_write_pong().
 *
 * Copies `len` bytes from `data` into a heap buffer laid out as
 * libwebsocket_write() requires:
 *   LWS_SEND_BUFFER_PRE_PADDING | payload | NUL | LWS_SEND_BUFFER_POST_PADDING
 * It then sends the buffer with `flags` and frees it. The buffer used to
 * be a stack VLA sized by the caller's `len`, which could overflow the
 * stack for large messages.
 *
 * If `as_string` is non-zero, the copy uses strncpy() semantics: it stops
 * at the first NUL and zero-fills the rest, as ws_write_text() always did.
 *
 * Returns 0 on success, -1 if the buffer cannot be allocated, or the
 * negative value returned by libwebsocket_write().
 */
static int
write_padded_frame(PER_SESSION_DATA* ps, const void* data, int len,
                   int as_string, unsigned int flags)
{
  unsigned char* frame;
  unsigned char* payload;
  int r;

  frame = malloc(LWS_SEND_BUFFER_PRE_PADDING + (size_t)len + 1 +
                 LWS_SEND_BUFFER_POST_PADDING);
  if (frame == NULL)
    return -1;

  payload = frame + LWS_SEND_BUFFER_PRE_PADDING;
  if (len > 0) {
    if (as_string)
      strncpy((char* )payload, (const char* )data, (size_t)len);
    else
      memcpy(payload, data, (size_t)len);
  }
  payload[len] = '\0';

  r = libwebsocket_write(ps->wsi, payload, (size_t)len, flags);
  free(frame);

  return (r < 0) ? r : 0;
}

/*
 * Sends `len` bytes of `s` as a text frame. Options are OR-ed into the
 * write flags. Returns 0 if len < 0 (nothing sent), -1 if the session is
 * closed, or the write_padded_frame() result.
 */
extern int
ws_write_text(PER_SESSION_DATA* ps, char* s, int len)
{
  if (len < 0) return 0;
  if (ps->opened == FALSE) return -1;

  return write_padded_frame(ps, s, len, 1, Options | LWS_WRITE_TEXT);
}

/*
 * Sends `len` bytes of `s` as a binary frame. Options are OR-ed into the
 * write flags. Returns 0 if len < 0 (nothing sent), -1 if the session is
 * closed, or the write_padded_frame() result.
 */
extern int
ws_write_binary(PER_SESSION_DATA* ps, unsigned char* s, int len)
{
  if (len < 0) return 0;
  if (ps->opened == FALSE) return -1;

  return write_padded_frame(ps, s, len, 0, Options | LWS_WRITE_BINARY);
}

/*
 * Sends a PONG frame carrying `len` bytes of `s`; len may be 0. Options is
 * deliberately not applied, as before. Returns 0 if len < 0 (nothing
 * sent), -1 if the session is closed, or the write_padded_frame() result.
 */
extern int
ws_write_pong(PER_SESSION_DATA* ps, char* s, int len)
{
  if (len < 0) return 0;
  if (ps->opened == FALSE) return -1;

  return write_padded_frame(ps, s, len, 0, LWS_WRITE_PONG);
}

/*
 * Opens a client connection on `context` using the registered sub-protocol
 * named `sub_protocol_name`. It then services the context for up to
 * `timeout` ms (<= 0: no limit) until the connection is established.
 *
 * Returns the new session's PER_SESSION_DATA. Returns NULL when:
 * - `context` or `sub_protocol_name` is NULL,
 * - the context was not created by ws_create_context(),
 * - the sub-protocol is not registered,
 * - libwebsocket_client_connect() fails, or
 * - no connection is established in time.
 */
extern PER_SESSION_DATA*
ws_connect(int timeout, struct libwebsocket_context* context,
           const char* sub_protocol_name,
           const char*address, int port, int use_ssl,
           const char* path, const char* host, const char* origin, int ietf_version)
{
  int i;
  struct libwebsocket *wsi;
  struct libwebsocket_protocols *protocols;
  CONTEXT_ASSOC* ca;
  const char* proto_name = NULL;

  if (context == NULL || sub_protocol_name == NULL)
    return NULL;

  ca = context_assoc_find(context);
  if (ca == NULL || ca->protocols == NULL)
    return NULL;

  /* the protocol table is terminated by an entry whose callback is NULL */
  protocols = ca->protocols;
  for (i = 0; protocols[i].callback != NULL; i++) {
    if (protocols[i].name != NULL &&
        strcmp(sub_protocol_name, protocols[i].name) == 0) {
      proto_name = protocols[i].name;
      break;
    }
  }

  if (proto_name == NULL)
    return NULL;

  /* the ESTABLISHED callback sets this; connect_wait() polls for it */
  LastConnectedSession = NULL;
  wsi = libwebsocket_client_connect(context,
                                    address, port, use_ssl, path, host, origin,
                                    proto_name, ietf_version);
  if (wsi == NULL) {
#ifdef DEBUG
    fprintf(stderr, "libwebsocket connect failed. (wsi)\n");
#endif
    return NULL;
  }

  if (connect_wait(timeout, context, wsi) < 0) {
#ifdef DEBUG
    fprintf(stderr, "libwebsocket connect failed.\n");
#endif
    return NULL;
  }

  return LastConnectedSession;
}

/*
 * Sends a WebSocket Close frame on an open session.
 *
 * `psd->opened` is not changed here. It stays TRUE until the
 * LWS_CALLBACK_CLOSED handler clears it. The former trailing
 * `psd->opened = FALSE;` was unreachable (both branches of the preceding
 * if/else returned) and has been removed without changing behavior.
 *
 * Returns 0 on success or if the session is already closed, otherwise
 * the negative value returned by libwebsocket_write().
 */
extern int
ws_close(PER_SESSION_DATA* psd)
{
  unsigned char buf[LWS_SEND_BUFFER_PRE_PADDING + 1 + LWS_SEND_BUFFER_POST_PADDING];
  int r;

  if (psd->opened == FALSE)
    return 0;

  buf[LWS_SEND_BUFFER_PRE_PADDING] = '\0';
  r = libwebsocket_write(psd->wsi,
                         &buf[LWS_SEND_BUFFER_PRE_PADDING], 0, LWS_WRITE_CLOSE);

  return (r < 0) ? r : 0;
}



/*
 * Reports whether the session is closed. If it is still open, the context
 * is serviced once (up to `timeout` ms) before `opened` is re-checked.
 * Returns TRUE if closed, FALSE otherwise.
 */
extern int
ws_is_close(PER_SESSION_DATA* psd, int timeout)
{
  if (psd->opened != FALSE) {
    libwebsocket_service(psd->context, timeout);
  }

  return (psd->opened == FALSE) ? TRUE : FALSE;
}

/* Frees the names of the first `count` entries of `protocols`, then the
   array itself. */
static void
protocol_table_free(struct libwebsocket_protocols* protocols, int count)
{
  int i;

  for (i = 0; i < count; i++)
    free((char* )protocols[i].name);
  free(protocols);
}

/*
 * Creates a client-only libwebsocket context (no listening port).
 *
 * One protocol is registered per entry of `sub_protocols`. Every protocol
 * uses callback_simple_request_response, with PER_SESSION_DATA as its
 * session data and a WS_BUFFER_SIZE receive buffer. An empty name
 * registers an unnamed protocol. The protocol table is recorded in
 * ContextAssocList and freed by ws_delete_context().
 *
 * `default_protocol_index` is accepted for API compatibility and unused.
 *
 * On success stores the context in *r_context and returns 0. On failure
 * *r_context is NULL and nothing leaks:
 *   -1  invalid count, allocation failure, or context creation failure
 *   -2  the context table is full
 */
extern int
ws_create_context(int sub_protocol_num,
                  const char sub_protocols[][WS_SUB_PROTOCOL_NAME_MAX_SIZE],
                  int default_protocol_index,
                  struct libwebsocket_context** r_context)
{
  int i;
  struct libwebsocket_protocols* protocols;
  struct libwebsocket_context* context;
  struct lws_context_creation_info info;

  (void)default_protocol_index;
  *r_context = NULL;

  if (sub_protocol_num < 0)
    return -1;

  /* calloc zero-fills every field of every entry, including the
     terminating entry (callback == NULL) libwebsockets looks for */
  protocols = calloc((size_t)sub_protocol_num + 1, sizeof(*protocols));
  if (protocols == NULL)
    return -1;

  for (i = 0; i < sub_protocol_num; i++) {
    const char* name = sub_protocols[i];

    if (name[0] != '\0') {
      protocols[i].name = strdup(name);
      if (protocols[i].name == NULL) {
        protocol_table_free(protocols, i);
        return -1;
      }
    }
    protocols[i].callback              = callback_simple_request_response;
    protocols[i].per_session_data_size = sizeof(PER_SESSION_DATA);
    protocols[i].rx_buffer_size        = WS_BUFFER_SIZE;
  }

  memset(&info, 0, sizeof(info));
  info.port       = CONTEXT_PORT_NO_LISTEN;
  info.protocols  = protocols;
  info.extensions = libwebsocket_get_internal_extensions();
  info.gid        = -1;
  info.uid        = -1;

  context = libwebsocket_create_context(&info);
  if (context == NULL) {
    protocol_table_free(protocols, sub_protocol_num);
    return -1;
  }

  if (context_assoc_add(context, protocols) == NULL) {
    /* The context still references `protocols`, so destroy it first.
       Previously the table was freed under a live context that was
       still handed back to the caller. */
    libwebsocket_context_destroy(context);
    protocol_table_free(protocols, sub_protocol_num);
    return -2;
  }

  *r_context = context;
  return 0;
}

/*
 * Destroys a context created by ws_create_context(). It also frees the
 * associated protocol table and the strdup'ed protocol names.
 *
 * The association entry is removed before the server-side check, as it
 * always was. A context whose first protocol is "http-only" is therefore
 * unlinked here but not destroyed.
 * NOTE(review): presumably server code owns and destroys such contexts —
 * confirm.
 *
 * Returns 0 on success, -1 if `context` is NULL, unknown, or server side.
 */
extern int
ws_delete_context(struct libwebsocket_context* context)
{
  struct libwebsocket_protocols* p;
  CONTEXT_ASSOC* ca;

  if (context == NULL)
    return -1;

  ca = context_assoc_find(context);
  if (ca == NULL)
    return -1;

  p = ca->protocols;
  context_assoc_clear(ca);

  /* p[0].name is NULL when the first sub-protocol name was empty, and
     strcmp(NULL, ...) would crash */
  if (p != NULL && p[0].name != NULL && strcmp(p[0].name, "http-only") == 0) {
    fprintf(stderr, "This context is server side.\n");
    return -1;
  }

  libwebsocket_context_destroy(context);

  if (p != NULL) {
    int i;
    for (i = 0; p[i].callback != NULL; i++)
      free((char* )p[i].name);
    free(p);
  }

  return 0;
}

/*
 * Returns the name of the protocol negotiated for an open session, as
 * reported by libwebsockets_get_protocol(), or NULL if the session is
 * closed.
 */
extern const char*
ws_session_protocol(PER_SESSION_DATA* ps)
{
  if (ps->opened != FALSE)
    return libwebsockets_get_protocol(ps->wsi)->name;

  return NULL;
}


#ifdef TEST

static struct option options[] = {
  { "help",       no_argument,       NULL, 'h' },
  { "port",       required_argument, NULL, 'p' },
  { "ssl",        no_argument,       NULL, 's' },
  { "killmask",   no_argument,       NULL, 'k' },
  { "version",    required_argument, NULL, 'v' },
  { "undeflated", no_argument,       NULL, 'u' },
  { "nomux",      no_argument,       NULL, 'n' },
  { NULL, 0, 0, 0 }
};

/*
 * Test driver. Connects to <server address> with sub-protocol "simple-rr"
 * on path "/echo". It sends ten text messages and prints each reply, or
 * "receive timeout" when none arrives. Finally it closes the session and
 * deletes the context.
 */
int main(int argc, char **argv)
{
  int r = 0;
  int n;
  int port = 8080;
  int use_ssl = 0;
  const char *address;
  int ietf_version = -1; /* latest */
  struct libwebsocket_context* context;
  /* Initialised in place. The array is const, so filling it afterwards
     through a cast (as the old strcpy calls did) is undefined behavior. */
  const char sub_protocols[2][WS_SUB_PROTOCOL_NAME_MAX_SIZE] = {
    "simple-rr",
    "foo"
  };
  const int connect_timeout = 10000; /* msec. */
  const int read_timeout = 5000;     /* msec. */
  const char* path = "/echo";
  const char* sub_protocol = "simple-rr";
  PER_SESSION_DATA* ps;
  char buf[1024];
  int counter = 0;

  fprintf(stderr, "libwebsockets test client\n");

  if (argc < 2)
    goto usage;

  while ((n = getopt_long(argc, argv, "nuv:khsp:", options, NULL)) >= 0) {
    switch (n) {
    case 's':
      use_ssl = 2; /* 2 = allow selfsigned */
      break;
    case 'p':
      port = atoi(optarg);
      break;
    case 'k':
      Options = LWS_WRITE_CLIENT_IGNORE_XOR_MASK;
      break;
    case 'v':
      ietf_version = atoi(optarg);
      break;
    case 'u':
      /* --undeflated: accepted but currently ignored */
      break;
    case 'n':
      /* --nomux: accepted but currently ignored */
      break;
    case 'h':
      goto usage;
    default:
      break;
    }
  }

  if (optind >= argc)
    goto usage;

  address = argv[optind];

  r = ws_create_context(2, sub_protocols, 0, &context);
  if (r < 0) {
    fprintf(stderr, "Creating libwebsocket context failed\n");
    return -1;
  }

  ps = ws_connect(connect_timeout, context, sub_protocol,
                  address, port, use_ssl, path,
                  argv[optind], argv[optind], ietf_version);
  if (ps == NULL) {
    fprintf(stderr, "libwebsocket connect failed\n");
    ws_delete_context(context);
    return -1;
  }

  fprintf(stderr, "connections opened\n");

  while (counter < 10 && ps->opened) {
    snprintf(buf, sizeof(buf), "Hello, World %d --", counter++);
    r = ws_write_text(ps, buf, (int )strlen(buf));
    if (r < 0)
      break;

    r = ws_read(read_timeout, ps, buf, (int )sizeof(buf));
    if (r < 0)
      break;
    else if (r == 0)
      fprintf(stderr, "receive timeout\n");
    else
      fprintf(stderr, "receive: %s, %d\n", buf, r);
  }

  ws_close(ps);

  /*
   * Wait a while before tearing the context down. The server answers our
   * Close frame, and libwebsocket_close_and_free_session() arms a socket
   * close timer of about 5 seconds.
   */
  usleep(2000000);
  ws_delete_context(context);
  fprintf(stderr, "Exiting\n");
  return 0;

usage:
  fprintf(stderr, "Usage: client <server address> [--port=<p>] [--ssl] [-k] [-v <ver>]\n");
  return 1;
}

#endif /* TEST */
