/*
 * rtp_decoder.c
 *
 * decoder structures and functions for SRTP pcap decoder
 *
 * Example:
 * $ wget --no-check-certificate https://raw.githubusercontent.com/gteissier/srtp-decrypt/master/marseillaise-srtp.pcap
 * $ ./test/rtp_decoder -a -t 10 -e 128 -b aSBrbm93IGFsbCB5b3VyIGxpdHRsZSBzZWNyZXRz \
 *    < ~/marseillaise-srtp.pcap | text2pcap -t "%M:%S." -u 10000,10000 - - > ./marseillaise-rtp.pcap
 *
 * There is also a different way of setting up key size and tag size
 * based upon RFC 4568 crypto suite specification, i.e.:
 *
 * $ ./test/rtp_decoder -s AES_CM_128_HMAC_SHA1_80 -b aSBrbm93IGFsbCB5b3VyIGxpdHRsZSBzZWNyZXRz ...
 *
 * Audio can be extracted using extractaudio utility from the RTPproxy
 * package:
 *
 * $ extractaudio -A ./marseillaise-rtp.pcap ./marseillaise-out.wav
 *
 * Bernardo Torres <bernardo@torresautomacao.com.br>
 *
 * Some structure and code from https://github.com/gteissier/srtp-decrypt
 */
/*
 *	
 * Copyright (c) 2001-2017 Cisco Systems, Inc.
 * All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 
 *   Redistributions of source code must retain the above copyright
 *   notice, this list of conditions and the following disclaimer.
 * 
 *   Redistributions in binary form must reproduce the above
 *   copyright notice, this list of conditions and the following
 *   disclaimer in the documentation and/or other materials provided
 *   with the distribution.
 * 
 *   Neither the name of the Cisco Systems, Inc. nor the names of its
 *   contributors may be used to endorse or promote products derived
 *   from this software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 * COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
 * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
 * OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 */
#include "getopt_s.h"       /* for local getopt()  */
#include <assert.h>         /* for assert()  */

#include <pcap.h>
#include "rtp_decoder.h"
#include "util.h"

#define MAX_KEY_LEN      96
#define MAX_FILTER 256

struct srtp_crypto_suite {
  const char *can_name;
  int key_size;
  int tag_size;
};

static struct srtp_crypto_suite srtp_crypto_suites[] = {
  {.can_name = "AES_CM_128_HMAC_SHA1_32", .key_size = 128, .tag_size = 4},
#if 0
  {.can_name = "F8_128_HMAC_SHA1_32", .key_size = 128, .tag_size = 4},
#endif
  {.can_name = "AES_CM_128_HMAC_SHA1_32", .key_size = 128, .tag_size = 4},
  {.can_name = "AES_CM_128_HMAC_SHA1_80", .key_size = 128, .tag_size = 10},
  {.can_name = NULL}
};

int
main (int argc, char *argv[]) {
  char errbuf[PCAP_ERRBUF_SIZE];
  bpf_u_int32 pcap_net = 0;
  pcap_t *pcap_handle;
#if BEW
  struct sockaddr_in local;
#endif
  srtp_sec_serv_t sec_servs = sec_serv_none;
  int c;
  struct srtp_crypto_suite scs, *i_scsp;
  scs.key_size = 128;
  scs.tag_size = 8;
  int gcm_on = 0;
  char *input_key = NULL;
  int b64_input = 0;
  char key[MAX_KEY_LEN];
  struct bpf_program fp;
  char filter_exp[MAX_FILTER] = "";
  rtp_decoder_t dec;
  srtp_policy_t policy;
  srtp_err_status_t status;
  int len;
  int expected_len;
  int do_list_mods = 0;

  fprintf(stderr, "Using %s [0x%x]\n", srtp_get_version_string(), srtp_get_version());

  /* initialize srtp library */
  status = srtp_init();
  if (status) {
    fprintf(stderr, "error: srtp initialization failed with error code %d\n", status);
    exit(1);
  }

  /* check args */
  while (1) {
    c = getopt_s(argc, argv, "b:k:gt:ae:ld:f:s:");
    if (c == -1) {
      break;
    }
    switch (c) {
	case 'b':
      b64_input = 1;
      /* fall thru */
    case 'k':
      input_key = optarg_s;
      break;
    case 'e':
      scs.key_size = atoi(optarg_s);
      if (scs.key_size != 128 && scs.key_size != 256) {
        fprintf(stderr, "error: encryption key size must be 128 or 256 (%d)\n", scs.key_size);
        exit(1);
      }
      input_key = malloc(scs.key_size);
      sec_servs |= sec_serv_conf;
      break;
    case 't':
      scs.tag_size = atoi(optarg_s);
      break;
    case 'a':
      sec_servs |= sec_serv_auth;
      break;
    case 'g':
      gcm_on = 1;
      sec_servs |= sec_serv_auth;
      break;
    case 'd':
      status = srtp_crypto_kernel_set_debug_module(optarg_s, 1);
      if (status) {
        fprintf(stderr, "error: set debug module (%s) failed\n", optarg_s);
        exit(1);
      }
      break;
    case 'f':
      if(strlen(optarg_s) > MAX_FILTER){
        fprintf(stderr, "error: filter bigger than %d characters\n", MAX_FILTER);
        exit(1);
      }
      fprintf(stderr, "Setting filter as %s\n", optarg_s);
      strcpy(filter_exp, optarg_s);
      break;
    case 'l':
      do_list_mods = 1;
      break;
    case 's':
      for (i_scsp = &srtp_crypto_suites[0]; i_scsp->can_name != NULL; i_scsp++) {
        if (strcasecmp(i_scsp->can_name, optarg_s) == 0) {
            break;
        }
      }
      if (i_scsp->can_name == NULL) {
        fprintf(stderr, "Unknown/unsupported crypto suite name %s\n", optarg_s);
        exit(1);
      }
      scs = *i_scsp;
      input_key = malloc(scs.key_size);
      sec_servs |= sec_serv_conf | sec_serv_auth;
      break;
    default:
      usage(argv[0]);
    }
  }

  if (gcm_on && scs.tag_size != 8 && scs.tag_size != 16) {
    fprintf(stderr, "error: GCM tag size must be 8 or 16 (%d)\n", scs.tag_size);
    //exit(1);
  }

  if (do_list_mods) {
    status = srtp_crypto_kernel_list_debug_modules();
    if (status) {
      fprintf(stderr, "error: list of debug modules failed\n");
	  exit(1);
    }
    return 0;
  }

  if ((sec_servs && !input_key) || (!sec_servs && input_key)) {
    /*
     * a key must be provided if and only if security services have
     * been requested
     */
	  if(input_key == NULL){
		  fprintf(stderr, "key not provided\n");
	  }
	  if(!sec_servs){
		  fprintf(stderr, "no secservs\n");
	  }
    fprintf(stderr, "provided\n");
    usage(argv[0]);
  }



  /* report security services selected on the command line */
  fprintf(stderr, "security services: ");
  if (sec_servs & sec_serv_conf)
    fprintf(stderr, "confidentiality ");
  if (sec_servs & sec_serv_auth)
    fprintf(stderr, "message authentication");
  if (sec_servs == sec_serv_none)
    fprintf(stderr, "none");
  fprintf(stderr, "\n");

  /* set up the srtp policy and master key */
  if (sec_servs) {
    /*
     * create policy structure, using the default mechanisms but
     * with only the security services requested on the command line,
     * using the right SSRC value
     */
    switch (sec_servs) {
    case sec_serv_conf_and_auth:
      if (gcm_on) {
#ifdef OPENSSL
	switch (scs.key_size) {
	case 128:
	  srtp_crypto_policy_set_aes_gcm_128_8_auth(&policy.rtp);
	  srtp_crypto_policy_set_aes_gcm_128_8_auth(&policy.rtcp);
	  break;
	case 256:
	  srtp_crypto_policy_set_aes_gcm_256_8_auth(&policy.rtp);
	  srtp_crypto_policy_set_aes_gcm_256_8_auth(&policy.rtcp);
	  break;
	}
#else
	fprintf(stderr, "error: GCM mode only supported when using the OpenSSL crypto engine.\n");
	return 0;
#endif
      } else {
	switch (scs.key_size) {
	case 128:
          srtp_crypto_policy_set_rtp_default(&policy.rtp);
          srtp_crypto_policy_set_rtcp_default(&policy.rtcp);
	  break;
	case 256:
          srtp_crypto_policy_set_aes_cm_256_hmac_sha1_80(&policy.rtp);
          srtp_crypto_policy_set_rtcp_default(&policy.rtcp);
	  break;
	}
      }
      break;
    case sec_serv_conf:
      if (gcm_on) {
	  fprintf(stderr, "error: GCM mode must always be used with auth enabled\n");
	  return -1;
      } else {
	switch (scs.key_size) {
	case 128:
          srtp_crypto_policy_set_aes_cm_128_null_auth(&policy.rtp);
          srtp_crypto_policy_set_rtcp_default(&policy.rtcp);
	  break;
	case 256:
          srtp_crypto_policy_set_aes_cm_256_null_auth(&policy.rtp);
          srtp_crypto_policy_set_rtcp_default(&policy.rtcp);
	  break;
	}
      }
      break;
    case sec_serv_auth:
      if (gcm_on) {
#ifdef OPENSSL
	switch (scs.key_size) {
	case 128:
	  srtp_crypto_policy_set_aes_gcm_128_8_only_auth(&policy.rtp);
	  srtp_crypto_policy_set_aes_gcm_128_8_only_auth(&policy.rtcp);
	  break;
	case 256:
	  srtp_crypto_policy_set_aes_gcm_256_8_only_auth(&policy.rtp);
	  srtp_crypto_policy_set_aes_gcm_256_8_only_auth(&policy.rtcp);
	  break;
	}
#else
	printf("error: GCM mode only supported when using the OpenSSL crypto engine.\n");
	return 0;
#endif
      } else {
        srtp_crypto_policy_set_null_cipher_hmac_sha1_80(&policy.rtp);
        srtp_crypto_policy_set_rtcp_default(&policy.rtcp);
      }
      break;
    default:
      fprintf(stderr, "error: unknown security service requested\n");
      return -1;
    }

    policy.key  = (uint8_t *) key;
    policy.ekt  = NULL;
    policy.next = NULL;
    policy.window_size = 128;
    policy.allow_repeat_tx = 0;
    policy.rtp.sec_serv = sec_servs;
    policy.rtcp.sec_serv = sec_servs; //sec_serv_none;  /* we don't do RTCP anyway */
    fprintf(stderr, "setting tag len %d\n", scs.tag_size);
    policy.rtp.auth_tag_len = scs.tag_size;

    if (gcm_on && scs.tag_size != 8) {
      fprintf(stderr, "setted tag len %d\n", scs.tag_size);
      policy.rtp.auth_tag_len = scs.tag_size;
    }

    /*
     * read key from hexadecimal or base64 on command line into an octet string
     */
    if (b64_input) {
      int pad;
      expected_len = policy.rtp.cipher_key_len*4/3;
      len = base64_string_to_octet_string(key, &pad, input_key, expected_len);
      if (pad != 0) {
        fprintf(stderr, "error: padding in base64 unexpected\n");
        exit(1);
      }
    } else {
      expected_len = policy.rtp.cipher_key_len*2;
      len = hex_string_to_octet_string(key, input_key, expected_len);
    }
    /* check that hex string is the right length */
    if (len < expected_len) {
      fprintf(stderr,
	      "error: too few digits in key/salt "
	      "(should be %d digits, found %d)\n",
	      expected_len, len);
      exit(1);
    }
    if (strlen(input_key) > policy.rtp.cipher_key_len*2) {
      fprintf(stderr,
	      "error: too many digits in key/salt "
	      "(should be %d hexadecimal digits, found %u)\n",
	      policy.rtp.cipher_key_len*2, (unsigned)strlen(input_key));
      exit(1);
    }

    fprintf(stderr, "set master key/salt to %s/", octet_string_hex_string(key, 16));
    fprintf(stderr, "%s\n", octet_string_hex_string(key+16, 14));

  } else {
      fprintf(stderr, "error: neither encryption or authentication were selected");
      exit(1);
  }

	pcap_handle = pcap_open_offline("-", errbuf);

	if (!pcap_handle) {
	    fprintf(stderr, "libpcap failed to open file '%s'\n", errbuf);
	    exit(1);
	}
	assert(pcap_handle != NULL);
	if ((pcap_compile(pcap_handle, &fp, filter_exp, 1, pcap_net)) == -1) {
	    fprintf(stderr, "Couldn't parse filter %s: %s\n", filter_exp,
	        pcap_geterr(pcap_handle));
	    return (2);
	}
	if (pcap_setfilter(pcap_handle, &fp) == -1) {
	  fprintf(stderr, "couldn't install filter %s: %s\n", filter_exp,
	      pcap_geterr(pcap_handle));
	  return (2);
	}
	dec = rtp_decoder_alloc();
	if (dec == NULL) {
	  fprintf(stderr, "error: malloc() failed\n");
	  exit(1);
	}
	fprintf(stderr, "Starting decoder\n");
	rtp_decoder_init(dec, policy);

	pcap_loop(pcap_handle, 0, rtp_decoder_handle_pkt, (u_char *)dec);

	rtp_decoder_deinit_srtp(dec);
	rtp_decoder_dealloc(dec);

  status = srtp_shutdown();
  if (status) {
    fprintf(stderr, "error: srtp shutdown failed with error code %d\n", status);
    exit(1);
  }

  return 0;
}


void
usage(char *string) {

  fprintf(stderr, "usage: %s [-d <debug>]* [[-k][-b] <key> [-a][-e]]\n"
	 "or     %s -l\n"
	 "where  -a use message authentication\n"
	 "       -e <key size> use encryption (use 128 or 256 for key size)\n"
	 "       -g Use AES-GCM mode (must be used with -e)\n"
	 "       -t <tag size> Tag size to use (in GCM mode use 8 or 16)\n"
	 "       -k <key>  sets the srtp master key given in hexadecimal\n"
	 "       -b <key>  sets the srtp master key given in base64\n"
	 "       -l list debug modules\n"
	 "       -f \"<pcap filter>\" to filter only the desired SRTP packets\n"
	 "       -d <debug> turn on debugging for module <debug>\n"
	 "       -s \"<srtp-crypto-suite>\" to set both key and tag size based\n"
	 "          on RFC4568-style crypto suite specification\n",
	 string, string);
  exit(1);

}

rtp_decoder_t
rtp_decoder_alloc(void) {
  return (rtp_decoder_t)malloc(sizeof(rtp_decoder_ctx_t));
}

void
rtp_decoder_dealloc(rtp_decoder_t rtp_ctx) {
  free(rtp_ctx);
}

srtp_err_status_t
rtp_decoder_init_srtp(rtp_decoder_t decoder, unsigned int ssrc) {
  decoder->policy.ssrc.value = htonl(ssrc);
  return srtp_create(&decoder->srtp_ctx, &decoder->policy);
}

int
rtp_decoder_deinit_srtp(rtp_decoder_t decoder) {
  return srtp_dealloc(decoder->srtp_ctx);
}

int
rtp_decoder_init(rtp_decoder_t dcdr, srtp_policy_t policy) {
	dcdr->rtp_offset = DEFAULT_RTP_OFFSET;
	dcdr->srtp_ctx = NULL;
	dcdr->start_tv.tv_usec = 0;
	dcdr->start_tv.tv_sec = 0;
	dcdr->frame_nr = -1;
  dcdr->policy = policy;
	dcdr->policy.ssrc.type  = ssrc_specific;
	return 0;
}

/*
 * decodes key as base64
 */

void hexdump(const void *ptr, size_t size) {
  int i, j;
  const unsigned char *cptr = ptr;

  for (i = 0; i < size; i += 16) {
    fprintf(stdout, "%04x ", i);
    for (j = 0; j < 16 && i+j < size; j++) {
      fprintf(stdout, "%02x ", cptr[i+j]);
    }
    fprintf(stdout, "\n");
  }
}

void
rtp_decoder_handle_pkt(u_char *arg, const struct pcap_pkthdr *hdr,
	const u_char *bytes) {
  rtp_decoder_t dcdr = (rtp_decoder_t)arg;
  int pktsize;
  struct timeval delta;
  int octets_recvd;
  srtp_err_status_t status;
  dcdr->frame_nr++;

  if ((dcdr->start_tv.tv_sec == 0) && (dcdr->start_tv.tv_usec == 0)) {
    dcdr->start_tv = hdr->ts;
  }

  if (hdr->caplen < dcdr->rtp_offset) {
    return;
  }
  const void *rtp_packet = bytes + dcdr->rtp_offset;

  memcpy((void *)&dcdr->message, rtp_packet, hdr->caplen - dcdr->rtp_offset);
  pktsize = hdr->caplen - dcdr->rtp_offset;
  octets_recvd = pktsize;

  if (octets_recvd == -1) {
    return;
  }

  /* verify rtp header */
  if (dcdr->message.header.version != 2) {
    return;
  }
  if(dcdr->srtp_ctx == NULL) {
    status = rtp_decoder_init_srtp(dcdr, dcdr->message.header.ssrc);
    if (status) {
      exit(1);
    }
  }
  status = srtp_unprotect(dcdr->srtp_ctx, &dcdr->message, &octets_recvd);
  if (status) {
    return;
  }
  timersub(&hdr->ts, &dcdr->start_tv, &delta);
  fprintf(stdout, "%02ld:%02ld.%06ld\n", delta.tv_sec/60, delta.tv_sec%60, (long)delta.tv_usec);
  hexdump(&dcdr->message, octets_recvd);
}

void rtp_print_error(srtp_err_status_t status, char *message) {
    fprintf(stderr,
            "error: %s %d%s\n", message, status,
            status == srtp_err_status_replay_fail ? " (replay check failed)" :
            status == srtp_err_status_bad_param ? " (bad param)" :
            status == srtp_err_status_no_ctx ? " (no context)" :
            status == srtp_err_status_cipher_fail ? " (cipher failed)" :
            status == srtp_err_status_key_expired ? " (key expired)" :
            status == srtp_err_status_auth_fail ? " (auth check failed)" : "");
}