#!/usr/bin/env python
#
# cigetcert gets an X.509 certificate from an SP using the ECP profile.
# Optionally it can also get a grid proxy certificate and/or transfer
#   the proxy to MyProxy.
#
# Acronyms used:
#  SP - Service Provider (cilogon)
#  IdP - Identity Provider
#  SAML - Security Assertion Markup Language
#  ECP - Enhanced Client or Proxy SAML Profile
#
# Nonstandard python libraries required:
#  m2crypto
#  pyOpenSSL
#  python-kerberos
#  python-lxml

# Except where noted, this source file is Copyright (c) 2015-2016, FERMI
#   NATIONAL ACCELERATOR LABORATORY.  All rights reserved. 
#
# For details of the Fermitools (BSD) license see COPYING or
#  http://fermitools.fnal.gov/about/terms.html
#
# Author: Dave Dykstra dwd@fnal.gov


# Program name used as the prefix of messages, and the program version.
prog, version = "cigetcert", "1.0"

import sys
import os
import re
from lxml import etree
import httplib
import socket
import urllib
import urllib2
import urlparse
import cookielib
import kerberos
import getpass
import base64
import string
import random
import time
from M2Crypto import SSL, X509, EVP, RSA, ASN1, m2
from OpenSSL import crypto

import shlex
from optparse import OptionParser

# Built-in default values for the URL and CA-location command-line options.
defaults = dict(
    spurl="https://ecp.cilogon.org/secure/getcert",
    idplisturl="https://cilogon.org/include/ecpidps.txt",
    cafile="/etc/pki/tls/cert.pem",
    capath="/etc/grid-security/certificates"
)

# Module-wide state shared by the functions below.
#   options      -- parsed command-line options, filled in by main()
#   showprogress -- True when the short progress display is in use
options, showprogress = None, False

def usage(parser, msg):
    """Report a command-line usage error and exit with status 2.

    Writes "<prog>: <msg>" followed by a blank line to stderr, then the
    parser's full help text, then exits.
    """
    sys.stderr.write(prog + ": " + msg + '\n\n')
    parser.print_help(sys.stderr)
    sys.exit(2)

def fatal(msg, code=1):
    """Print an error message to stderr and exit with the given status.

    The message is suppressed when --quiet was given.  When the progress
    display is active, a newline is written first so the message starts
    on its own line.
    """
    quiet = (options is not None) and options.quiet
    if not quiet:
        if showprogress:
            sys.stderr.write('\n')
        sys.stderr.write(prog + ": " + msg + '\n\n')
    sys.exit(code)

# Like fatal(), but append the exception's type name and text to the message.
def efatal(msg, e, code=1):
    """Exit via fatal() with "msg: ExceptionType: exception text"."""
    detail = '%s: %s' % (type(e).__name__, str(e))
    fatal(msg + ': ' + detail, code)

# this is from http://stackoverflow.com/questions/27227221/error-when-using-xml-api-findall-in-with-python-2-6
def resolve_xpath(xpath, namespace):
    """Expand namespace prefixes in a path string into "{url}" notation.

    Every "prefix:" that starts at a word boundary, where prefix is a key of
    namespace, is replaced by "{url}" for that prefix.

    xpath     -- path string, e.g. "S:Envelope/S:Header"
    namespace -- dict mapping prefix -> namespace URL
    Returns the rewritten string.
    """
    result = xpath
    for short_name, url in namespace.items():
        clark = '{' + url + '}'
        # re.escape() makes a prefix containing regex metacharacters
        #   (e.g. '.') match only literally.
        pattern = r'\b' + re.escape(short_name) + ':'
        # A function replacement inserts the URL verbatim; a plain string
        #   would have its backslashes / group references interpreted.
        result = re.sub(pattern, lambda match, text=clark: text, result)
    return result

# this is from http://python-notes.curiousefficiency.org/en/latest/python_kerberos.html
def www_auth(handle):
    """Parse the WWW-Authenticate header of an HTTP response into a dict.

    handle -- response object; handle.info().getheader() supplies the header
              (an empty string is used when it is absent)
    Returns a dict that maps each lower-cased scheme name to the rest of
    its field, stripped.  The scheme is the text before the first space of
    each comma-separated field.  A field without a space maps to "".
    NOTE: the header is split on every comma, so a scheme whose parameters
    contain commas produces extra entries.
    """
    header = handle.info().getheader("www-authenticate", "")
    schemes = {}
    for raw_field in header.split(","):
        pieces = raw_field.strip().split(" ", 1)
        scheme = pieces[0]
        if len(pieces) > 1:
            params = pieces[1]
        else:
            params = ""
        schemes[scheme.lower()] = params.strip()
    return schemes

# function from http://stackoverflow.com/questions/4407539/python-how-to-make-an-option-to-be-required-in-optparse
def checkRequiredOptions(parser):
    """Exit through usage() if any option marked "(required)" is unset.

    An option is required when its help text ends with "(required)".  It is
    missing when its value in the global options is None (or absent).
    """
    missing_options = []
    for option in parser.option_list:
        # Options without a destination (e.g. optparse's automatic --help and
        #   --version) or without help text cannot be marked required; skip
        #   them instead of feeding None into string operations.
        if option.dest is None or not option.help:
            continue
        # getattr() instead of eval('options.' + dest): same lookup without
        #   evaluating constructed source code.
        if option.help.endswith('(required)') and \
                getattr(options, option.dest, None) is None:
            missing_options.extend(option._long_opts)
    if missing_options:
        usage(parser, "Missing required parameters: " + str(missing_options))

# M2Crypto's X509_Name as_text() returns a comma-separated list; convert it
#   to the conventional slash-separated form.
def x509name_to_str(name):
    """Return name formatted like "/C=US/O=Org/CN=User".

    name -- object with an as_text() method (an M2Crypto X509_Name) whose
            components are separated by ", ".
    """
    components = name.as_text().split(', ')
    return '/' + '/'.join(components)


# Wrapper that lets an M2Crypto SSL Connection object be opened as a file.
# This is mostly from
# http://git.ganeti.org/?p=ganeti.git;a=commitdiff;h=beba56ae8;hp=70c815118f7f8bf151044cb09868d1e3d7a63ac8
_SslSocketWrapperSkipClose = True
class _SslSocketWrapper(object):
    """Forward socket operations to a wrapped SSL connection.

    makefile() returns a socket._fileobject over the wrapped connection.
    close() does nothing while the module flag _SslSocketWrapperSkipClose
    is true.
    """
    def __init__(self, sock):
        self._sock = sock

    def __getattr__(self, name):
        # Called only for names not defined on the wrapper, so everything
        #   except makefile() and close() reaches the wrapped connection.
        return getattr(self._sock, name)

    def makefile(self, mode, bufsize):
        return socket._fileobject(self._sock, mode, bufsize)

    def close(self):
        if not _SslSocketWrapperSkipClose:
            return self._sock.close()
        # Deliberately leave the connection open to avoid premature closure.
        # Leaving the file descriptors open is not ideal, but without this
        #   it doesn't work; they all get closed at program completion.
        return None

# Validate the server certificate on an HTTPS connection with M2Crypto.
class CertValidatingHTTPSConnection(httplib.HTTPConnection):
    """HTTP connection carried over an M2Crypto SSL connection.

    The peer must present a certificate that verifies against cafile and/or
    capath (chain depth up to 9), otherwise connect() fails.
    key_file/cert_file and cert_chain_file optionally supply a client
    credential.  connect() raises RuntimeError if the CA locations cannot
    be loaded.
    """
    default_port = httplib.HTTPS_PORT

    def __init__(self, host, port=None, key_file=None, cert_file=None,
            cert_chain_file=None, cafile=None, capath=None, strict=None,
            **kwargs):
        # HTTPConnection.__init__ sets self.host and self.port, splitting a
        #   "host:port" string when port is not given.  self.host must not
        #   be reassigned from the raw host argument afterwards, or the
        #   ":port" suffix would come back and connect() would try to
        #   resolve "host:port" as a host name.
        httplib.HTTPConnection.__init__(self, host, port, strict, **kwargs)
        self.key_file = key_file
        self.cert_file = cert_file
        self.cert_chain_file = cert_chain_file
        self.cafile = cafile
        self.capath = capath

    def connect(self):
        # the example used for M2Crypto connections was mostly
        #   http://www.heikkitoivonen.net/blog/2008/10/14/ssl-in-python-26/
        # 'sslv23' negotiates the highest protocol version both ends
        #   support, whereas 'tlsv1' allows only TLS 1.0, which many servers
        #   refuse.  The obsolete SSLv2/SSLv3 protocols are switched off
        #   explicitly (when this M2Crypto exposes the constants).
        context = SSL.Context('sslv23')
        disabled = getattr(m2, 'SSL_OP_NO_SSLv2', 0) | \
                getattr(m2, 'SSL_OP_NO_SSLv3', 0)
        if disabled:
            context.set_options(disabled)
        if (self.cert_file is not None) or (self.key_file is not None):
            context.load_cert(self.cert_file, self.key_file)
        if self.cert_chain_file is not None:
            context.load_cert_chain(self.cert_chain_file)
        context.set_verify(SSL.verify_peer | SSL.verify_fail_if_no_peer_cert,
                depth=9)
        if context.load_verify_locations(self.cafile, self.capath) != 1:
            raise RuntimeError('Could not load verify locations ' + \
                    str(self.cafile) + ' ' + str(self.capath))
        sslconn = SSL.Connection(context)
        sslconn.connect((self.host, self.port))
        self.sock = _SslSocketWrapper(sslconn)

class VerifiedHTTPSHandler(urllib2.HTTPSHandler):
    def __init__(self, **kwargs):
        urllib2.HTTPSHandler.__init__(self)
        self._connection_args = kwargs

    def https_open(self, req):
        def http_class_wrapper(host, **kwargs):
            full_kwargs = dict(self._connection_args)
            full_kwargs.update(kwargs)
            return CertValidatingHTTPSConnection(host, **full_kwargs)

        return self.do_open(http_class_wrapper, req)

    # also don't raise an exception for 401 Not authorized errors
    def http_error_401(self, request, response, code, msg, hdrs):
	global options
	if options.debug:
	    print "###### Ignoring Not authorized"
	return response


# Convert an ASN1_UTCTIME to seconds since the epoch.
# Would have used get_datetime() except it isn't supported before
#   m2crypto 0.20 which is too new for RHEL5.
def asn1time_local_secs(asn1time):
    """Return the time in asn1time as integer seconds since the epoch.

    asn1time -- object whose str() is like "Jul  1 12:00:00 2016 GMT"
                (an M2Crypto ASN1_UTCTIME)
    The result is comparable with time.time().
    """
    import calendar
    timestruct = time.strptime(str(asn1time), '%b %d %H:%M:%S %Y GMT')
    # The text is GMT, so interpret the struct as UTC with timegm().
    #   time.mktime(ts) - time.timezone was wrong by one hour whenever the
    #   date fell in local daylight saving time: mktime() applies the DST
    #   offset, but time.timezone is the standard-time offset.
    return int(calendar.timegm(timestruct))

### create a proxy certificate of a certificate ####
# Based on code from the gridproxy library
#  https://github.com/abbot/gridproxy/blob/master/gridproxy/__init__.py
# which is Copyright Lev Shamardin and covered under the GNU GPLv3 license.
def generate_proxycert(cert, certprivkey, lifehours, limited=False, bits=2048):
    """Create an RFC proxy certificate issued from cert.

    cert        -- M2Crypto X509 certificate the proxy is derived from
    certprivkey -- M2Crypto RSA private key of cert, used to sign the proxy
    lifehours   -- requested proxy lifetime in hours; the proxy never
                   expires later than cert
    limited     -- if true, the proxyCertInfo policy language is
                   1.3.6.1.4.1.3536.1.1.1.9 instead of "Inherit all"
    bits        -- size in bits of the newly generated proxy RSA key
    Returns a tuple (proxy certificate PEM, unencrypted proxy key PEM).
    """
    import struct

    # according to
    #   https://en.wikipedia.org/wiki/RSA_%28cryptosystem%29#Key_generation
    # the exponent 65537 (2^16+1) is most efficient
    proxyrsa = RSA.gen_key(bits, 65537, lambda x: None)
    proxykey = EVP.PKey()
    proxykey.assign_rsa(proxyrsa)

    proxy = X509.X509()
    proxy.set_pubkey(proxykey)
    proxy.set_version(2)    # zero-based version field: 2 means X.509 v3

    now = int(time.time())
    not_before = ASN1.ASN1_UTCTIME()
    not_before.set_time(now)
    proxy.set_not_before(not_before)
    not_after = ASN1.ASN1_UTCTIME()
    not_after_time = now + int(lifehours * 60 * 60)
    # make sure proxy doesn't expire later than the underlying cert
    cert_not_after_time = asn1time_local_secs(cert.get_not_after())
    if not_after_time > cert_not_after_time:
        not_after_time = cert_not_after_time
    not_after.set_time(not_after_time)
    proxy.set_not_after(not_after)

    proxy.set_issuer_name(cert.get_subject())
    # Serial number: first 4 bytes of a hash of the new public key.  This
    #   hash is only an identifier, not a signature, so SHA-1 is fine here.
    digest = EVP.MessageDigest('sha1')
    digest.update(proxykey.as_der())
    serial = struct.unpack("<L", digest.final()[:4])[0]
    proxy.set_serial_number(int(serial & 0x7fffffff))

    # Proxy subject = the issuing certificate's subject plus one CN entry.
    # The entries are copied with the low-level m2 call; the finished
    #   certificate is serialized with as_pem() before returning, so no
    #   M2Crypto object built here escapes this function.
    proxy_subject = X509.X509_Name()
    subject = cert.get_subject()
    for idx in xrange(subject.entry_count()):
        entry = subject[idx].x509_name_entry
        m2.x509_name_add_entry(proxy_subject._ptr(), entry, -1, 0)
    proxy_subject.add_entry_by_txt('CN', ASN1.MBSTRING_ASC,
                                   str(serial), -1, -1, 0)
    proxy.set_subject(proxy_subject)
    proxy.add_ext(X509.new_extension("keyUsage",
        "Digital Signature, Key Encipherment, Data Encipherment", 1))
    if limited:
        proxy.add_ext(X509.new_extension("proxyCertInfo",
            "critical, language:1.3.6.1.4.1.3536.1.1.1.9", 1))
    else:
        proxy.add_ext(X509.new_extension("proxyCertInfo",
            "critical, language:Inherit all", 1))

    # Sign with SHA-256 when this M2Crypto provides it: SHA-1 certificate
    #   signatures are deprecated and refused by many current OpenSSL
    #   configurations.  Very old M2Crypto builds without m2.sha256 fall
    #   back to SHA-1 (X509.sign looks the digest name up on m2).
    if hasattr(m2, 'sha256'):
        sigdigest = 'sha256'
    else:
        sigdigest = 'sha1'
    sign_pkey = EVP.PKey()
    # capture=0: the PKey must not take ownership of certprivkey's key
    sign_pkey.assign_rsa(certprivkey, 0)
    proxy.sign(sign_pkey, sigdigest)

    return (proxy.as_pem(), proxykey.as_pem(None))

# Replace %certsubject in a string with a certificate subject.
def replace_certsubject(strng, cert):
    """Return strng with every "%certsubject" replaced by a subject DN.

    cert may be a base certificate or one level of proxy above it.  A
    proxy's subject is its issuer's subject plus one more component.  In
    that case the issuer's subject (the base certificate's DN) is used.
    """
    certsubject = x509name_to_str(cert.get_subject())
    issuer = x509name_to_str(cert.get_issuer())
    # Require a "/" right after the issuer DN.  A bare startswith(issuer)
    #   would also match an unrelated subject that merely shares a textual
    #   prefix (e.g. issuer ".../CN=Foo", subject ".../CN=Foobar").
    if certsubject.startswith(issuer + '/'):
        certsubject = issuer
    return strng.replace('%certsubject', certsubject)

# start connection to MyProxy and send it a command
# returns ssl "socket"
def start_myproxy_command(chainfile, command, username, passphrase, lifehours,
	    retrievers=None):
    # The protocol with myproxy is not https but create an HTTPS
    #  connection just to validate the certificate, then use the ssl
    #  "socket" directly.
    conn = CertValidatingHTTPSConnection(options.myproxyserver, port=7512,
	cafile=options.cafile, capath=options.capath, 
	cert_chain_file=chainfile)
    try:
	conn.connect()
    except Exception, e:
	efatal("failure connecting to MyProxy server %s" % options.myproxyserver,e)
    sslsock = conn.sock
    sslsock.write('0') # required by MyProxy protocol

    storecmd = 'VERSION=MYPROXYv2\n'
    storecmd += 'COMMAND=' + str(command) + '\n'
    storecmd += 'USERNAME=' + username + '\n'
    storecmd += 'PASSPHRASE=' + passphrase + '\n'
    storecmd += 'LIFETIME=' + str(int(lifehours * 60.0 * 60.0)) + '\n'
    if retrievers != None:
	storecmd += 'RETRIEVER_TRUSTED=' + retrievers + '\n'

    if options.debug:
	print "###### Begin MyProxy command"
	sys.stdout.write(storecmd)
	print "###### End MyProxy command"

    sslsock.write(storecmd)

    return sslsock

# Read and parse a myproxy response.  
# Returns: integer response, error text, integer end time
def parse_myproxy_response(sslsock):

    text = sslsock.recv(8192)
    if options.debug:
	print "###### Begin MyProxy response"

    response = 1
    params = {}
    for line in text.split('\n'):
	if '=' not in line:
	    continue
	if options.debug:
	    print line
	sep = line.index('=')
	key = line[0:sep]
	value = line[sep+1:]
	if key == 'RESPONSE':
	    response = int(value)
	elif key in params:
	    params[key] += ' ' + value
	else:
	    params[key] = value

    if options.debug:
	print "###### End MyProxy response"
    
    return response, params

###  cigetcert main ####
def main():
    global options
    usagestr = "usage: %prog [-h] [otheroptions]"
    parser = OptionParser(usage=usagestr, version=version, prog=prog)

    parser.add_option("-v", "--verbose", 
                      action="store_true", default=False,
                      help="write detailed progress to stdout")
    parser.add_option("-d", "--debug", 
                      action="store_true", default=False,
                      help="write debug output to stdout (implies -v)")
    parser.add_option("-q", "--quiet", 
                      action="store_true", default=False,
                      help="do not print progress or error messages")
    parser.add_option("-s", "--optserver", 
                      metavar="HostOrURL",
                      help="server or URL with default %s options" % prog)
    parser.add_option("-i", "--institution", 
                      metavar="Name",
                      help="Institution name (required)")
    parser.add_option("", "--listinstitutions", 
    		      action="store_true", default=False,
                      help="List available institution names and exit")
    parser.add_option("", "--idplisturl", 
                      metavar="URL", default=defaults['idplisturl'],
                      help="Identity Provider list URL")
    parser.add_option("", "--spurl", 
                      metavar="URL", default=defaults['spurl'],
                      help="Service Provider URL")
    parser.add_option("", "--cafile", 
                      metavar="file", default=defaults['cafile'],
                      help="Certifying Authority certificates bundle file")
    parser.add_option("", "--capath", 
                      metavar="path", default=defaults['capath'],
                      help="Certifying Authority certificates and CRLs directory")
    parser.add_option("-k", "--kerberos", 
                      action="store_true", default=False,
                      help="prefer kerberos authentication if available")
    parser.add_option("-n", "--noprompt", 
                      action="store_true", default=False,
                      help="do not prompt for password (implies --kerberos)")
    parser.add_option("-p", "--promptstr", 
                      metavar="str", default="Password for %username@%realm",
                      help="prompt string")
    parser.add_option("-u", "--username", 
                      metavar="str", default="$LOGNAME",
                      help="username for authentication")
    parser.add_option("-o", "--out", 
                      metavar="path", default="/tmp/x509up_u%uid",
                      help="file path to save certificate and key chain")
    parser.add_option("", "--minhours", 
                      type="float", metavar="num", default=12,
                      help="minimum hours remaining in existing cert chain " + \
			    "to keep using it instead of making a new one")
    weekhours = 24 * 7
    weekstr = str(weekhours)
    yearhours = int(24 * (365.5 + 31))
    yearstr = str(yearhours)
    defproxyhours = weekhours
    defproxystr = str(defproxyhours)
    parser.add_option("", "--hours", 
                      type="float", metavar="num", default=weekhours,
                      help="lifetime hours of the certificate [max: " + \
			weekstr + " unless --myproxyserver is set, then " + \
			yearstr + "]")
    parser.add_option("", "--proxyhours", 
                      type="float", metavar="num",
                      help="lifetime hours of a proxy certificate " + \
			"[max: "  + weekstr + "] [default: %hours, or " + \
			defproxystr + " if %hours > " + weekstr + "]")
    parser.add_option("", "--proxy", 
                      action="store_true", default=False,
                      help="store proxy certificate instead of certificate in %out" + \
		        " [implied when %hours does not match %proxyhours]")
    parser.add_option("", "--myproxyserver", 
                      metavar="Host",
                      help="host name of MyProxy server for storing credentials")
    parser.add_option("", "--myproxyusername", 
                      metavar="str", default="%certsubject",
                      help="username on MyProxy server for naming credentials")
    parser.add_option("", "--myproxyretrievers", 
                      metavar="expr",
                      help="regular expression of certificate Distinguished" + \
		      	    " Names permitted to fetch %myproxyusername proxy from MyProxy")
    parser.add_option("", "--myproxyhours", 
                      type="float", metavar="num",
                      help="max lifetime hours of a proxy fetched from MyProxy" + \
			" [max: " + weekstr + "] [default: %proxyhours]")


    # add default value (if any) to the help messages that are strings
    for option in parser.option_list:
	if (option.default != ("NO", "DEFAULT")) and (option.action == "store"):
	    option.help += " [default: %default]"

    (options, args) = parser.parse_args()
    if len(args) != 0:
	usage(parser, "no non-option arguments expected")

    # Set up https handler/opener with cookies
    cookiejar = cookielib.CookieJar()
    cookiehandler = urllib2.HTTPCookieProcessor(cookiejar)
    httpshandler = VerifiedHTTPSHandler(cafile=options.cafile, capath=options.capath)
    if options.debug:
	httpshandler.set_http_debuglevel(1)
    # need to avoid redirects
    class NoRedirectHandler(urllib2.HTTPRedirectHandler):
	def http_error_302(self, request, response, code, msg, hdrs):
	    if options.debug:
		print "###### Ignoring redirect"
	    return response
    noredirecthandler = NoRedirectHandler()
    opener = urllib2.build_opener(cookiehandler, noredirecthandler, httpshandler)
    if options.optserver != None:
	# read additional options from optserver
	optserver = options.optserver
	if optserver.find('://') == -1:
	    optserver = 'https://' + optserver + '/' + prog + 'opts.txt'
	if options.verbose or options.debug:
	    print "Fetching options from " + optserver
	optrequest = urllib2.Request(url=optserver)
	try:
	    opthandle = opener.open(optrequest)
	except Exception, e:
	    efatal("fetch of options from %s failed" % optserver, e)
	opts = opthandle.read()
	if options.debug:
	    print "##### Begin additional options"
	    print opts
	    print "##### End additional options"
	try:
	    extraargs = shlex.split(opts, True)
	except Exception, e:
	    efatal("parsing options from %s failed" % optserver, e)
	(options, args) = parser.parse_args(extraargs + sys.argv[1:])
	if len(args) != 0:
	    usage(parser, "non-option arguments found at %s" % optserver)

    if options.listinstitutions:
	idplistrequest = urllib2.Request(url=options.idplisturl)
	try:
	    idplisthandle = opener.open(idplistrequest)
	except Exception, e:
	    efatal("fetch of idplist from %s failed" % options.idplisturl, e)
	idplist = idplisthandle.read()

	prevname = ''
	for line in idplist.splitlines():
	    idx = line.index(' ')
	    name = line[idx+1:].replace(' (Kerberos)','')
	    if name != prevname:
		print name
	    prevname = name
	sys.exit(0)

    checkRequiredOptions(parser)

    # calculate defaults for options that are too complex for "default" keyword
    if options.hours > weekhours:
	if options.proxyhours == None:
	    options.proxyhours = defproxyhours
    elif options.proxyhours == None:
	options.proxyhours = options.hours
    if options.myproxyhours == None:
	options.myproxyhours = options.proxyhours

    # check for min and max
    if options.minhours < 0:
	fatal('--minhours must be non-negative')
    if options.hours < 0:
	fatal('--hours must be non-negative')
    if (options.hours > weekhours) and (options.myproxyserver == None):
	fatal('--hours > ' + weekstr + ' and --myproxyserver not set')
    if options.hours > yearhours:
	fatal('--hours must be <= ' + yearstr)
    if options.proxyhours < 0:
	fatal('--proxyhours must be non-negative')
    if options.proxyhours > weekhours:
	fatal('--proxyhours must <= ' + weekstr)
    if options.proxyhours > options.hours:
	fatal('--proxyhours must <= --hours')
    if options.myproxyhours < 0:
	fatal('--myproxyhours must be non-negative')
    if options.myproxyhours > weekhours:
	fatal('--myproxyhours must <= ' + weekstr)
    if options.myproxyhours > options.hours:
	fatal('--myproxyhours must <= --hours')

    # set implied options
    if options.debug:
	options.verbose = True
    if options.noprompt:
	options.kerberos = True
    if options.hours != options.proxyhours:
	options.proxy = True
    global showprogress
    if not options.quiet and not options.verbose:
	showprogress = True

    if options.debug:
	print "###### Durations:"
	print "minhours: " + str(options.minhours)
	print "hours: " + str(options.hours)
	print "proxyhours: " + str(options.proxyhours)
	print "myproxyhours: " + str(options.myproxyhours)
	print

    ### Check to see if an adequate proxy or cert already exists
    username = options.username.replace("$LOGNAME", os.getenv("LOGNAME"))
    myproxyusername = options.myproxyusername.replace("%username",username)
    outfile = options.out.replace("%uid", str(os.getuid()))
    try:
	existing = X509.load_cert(outfile)
    except Exception, e:
	if options.debug:
	    print 'Could not load ' + outfile + ': ' + type(e).__name__ + ': ' + str(e), e
    else:
	if options.verbose:
	    print "Checking if %s has at least %s hours left" % (outfile, options.minhours)
	elif showprogress:
	    sys.stdout.write('Checking if ' + outfile + ' can be reused ...')
	    sys.stdout.flush()
	time_left = asn1time_local_secs(existing.get_not_after()) - time.time()
	if time_left < 0:
	    time_left = 0
	existingdn = existing.get_subject()
	if existingdn.organizationName != options.institution:
	    if options.verbose:
		print "The organization name does not match institution, skipping"
	    elif showprogress:
		sys.stdout.write('.')
		sys.stdout.flush()
	elif time_left <= (options.minhours * 60 * 60):
	    if options.verbose:
		print "%.2f hours remaining, not enough" % (time_left / 60.0 / 60.0)
	    elif showprogress:
		sys.stdout.write('.')
		sys.stdout.flush()
	else:
	    if options.verbose:
		print "%.2f hours remaining, enough to reuse" % (time_left / 60.0 / 60.0)
	    canreuse = False
	    if options.myproxyserver == None:
		canreuse = True
	    else:
		minhours = options.hours - options.proxyhours - options.minhours
		if options.verbose:
		    print "Checking if %s has at least %s hours left" % \
			    (options.myproxyserver, minhours)
		elif showprogress:
		    sys.stdout.write('.')
		    sys.stdout.flush()

		myproxyusername = replace_certsubject(myproxyusername, existing)

		sslsock = start_myproxy_command(outfile, '2', myproxyusername,
				'PASSPHRASE', 0)

		response, params = parse_myproxy_response(sslsock)
		if response:
		    if options.debug:
			print "##### Begin MyProxy error text"
			print params['ERROR']
			print "##### End MyProxy error text"
		    if options.verbose:
			print "No info retrieved from MyProxy, continuing"
		elif 'CRED_END_TIME' not in params:
		    fatal('no CRED_END_TIME in info retrieved from MyProxy')
		else:
		    endtime = int(params['CRED_END_TIME'])
		    time_left = endtime - int(time.time())
		    if time_left < 0:
			time_left = 0

		    if time_left <= (minhours * 60 * 60):
			if options.verbose:
			    print "%.2f hours remaining, not enough" % \
			    	(time_left / 60.0 / 60.0)
		    else:
			if options.verbose:
			    print "%.2f hours remaining, enough to reuse" % \
			    	(time_left / 60.0 / 60.0)
			if options.myproxyretrievers != None:
			    if showprogress:
				sys.stdout.write('.')
				sys.stdout.flush()
			    if ('CRED_RETRIEVER_TRUSTED' not in params) or \
				(params['CRED_RETRIEVER_TRUSTED'] != options.myproxyretrievers):
				if options.debug:
				    print "##### Begin MyProxy retrievers"
				    if 'CRED_RETRIEVER_TRUSTED' not in params:
					print 'None'
				    else:
					print params['CRED_RETRIEVER_TRUSTED']
				    print "##### End MyProxy retrievers"
				if options.verbose:
				    print "However the myproxyretrievers does not match"
			    else:
				if options.debug:
				    print "myproxyretrievers also matches"
				canreuse = True
			else:
			    canreuse = True

	    if canreuse:
		if showprogress:
		    print " yes"
		sys.exit(0)

	if showprogress:
	    print " no"

    ### Look up the IdP URL
    if options.verbose:
	print "Fetching list of IdPs from " + options.idplisturl
    elif showprogress:
	sys.stdout.write("Authorizing ...")
	sys.stdout.flush()
    idplistrequest = urllib2.Request(url=options.idplisturl)
    try:
	idplisthandle = opener.open(idplistrequest)
    except Exception, e:
	efatal("fetch of idplist from %s failed" % options.idplisturl, e)
    idplist = idplisthandle.read()

    idpurl = None
    idpkrburl = None
    for line in idplist.splitlines():
	idx = line.index(' ')
	name = line[idx+1:]
	if re.match(options.institution + '($| \()', name) != None:
	    url = line[0:idx]
	    if line.endswith(' (Kerberos)'):
		idpkrburl = url
	    else:
		idpurl = url
    if idpkrburl == None:
	# if there's no server explicitly marked for kerberos, it's
	#   possible the regular server supports it
	idpkrburl = idpurl

    if idpkrburl == None:
	fatal('No institution called "' + options.institution + '"\n' +
		'  in ' + options.idplisturl + '\n' +
		'  Use --listinstitutions to see available institutions')

    if options.debug:
	print '##### IdP URL: ' + str(idpurl)
	if options.kerberos:
	    print '##### Kerberos IdP URL: ' + str(idpkrburl)

    ### Begin the real SAML communication, starting with the SP ###
    headers = {
	'Accept' : 'text/html; application/vnd.paos+xml',
	'PAOS'   : 'ver="urn:liberty:paos:2003-08";"urn:oasis:names:tc:SAML:2.0:profiles:SSO:ecp"'
    }
    if not options.spurl.endswith('/'):
	options.spurl += '/'
    if options.verbose:
	print "Requesting authorization from SP " + options.spurl
    elif showprogress:
	sys.stdout.write('.')
	sys.stdout.flush()
    sprequest = urllib2.Request(url=options.spurl,headers=headers)
    try:
	sphandle = opener.open(sprequest)
    except Exception, e:
	efatal("first request to SP %s failed" % options.spurl, e)
    
    spetree = etree.XML(sphandle.read())

    if options.debug:
	print "##### Begin SP response"
	print etree.tostring(spetree, pretty_print=True)
	print "##### End SP response"

    # these are used for multiple XML parses below
    namespaces = {
        'ecp' : 'urn:oasis:names:tc:SAML:2.0:profiles:SSO:ecp',
        'S'   : 'http://schemas.xmlsoap.org/soap/envelope/',
        'paos': 'urn:liberty:paos:2003-08'
    }

    # pull out the RelayState
    try:
	relayState = spetree.xpath("//ecp:RelayState", namespaces=namespaces)[0]
    except Exception, e:
	efatal("Unable to parse RelayState element from SP response", e)

    if options.debug:
	print "###### Begin RelayState element"
	print etree.tostring(relayState, pretty_print=True)
	print "###### End RelayState element"

    # pull out the responseConsumerURL
    try:
	responseConsumerURL = spetree.xpath("/S:Envelope/S:Header/paos:Request/@responseConsumerURL", namespaces=namespaces)[0]
    except Exception, e:
	efatal("Unable to parse responseConsumerURL from SP response",  e)

    if options.debug:
	print "###### Begin responseConsumerUrl attribute"
	print responseConsumerURL
	print "###### End responseConsumerUrl attribute"

    # remove the SOAP header to pass the AuthnRequest on to the IdP
    idprequestbody = spetree
    header = idprequestbody[0]
    idprequestbody.remove(header)
    # can't pretty print here or the IdP doesn't like it
    idpbody=etree.tostring(idprequestbody)

    if options.debug:
	print "###### Begin IdP request body"
        print etree.tostring(idprequestbody, pretty_print=True)
	print "###### End IdP request body"

    wwwauthenticate = ""
    unauthidpurl = ""
    authidpurl = ""
    if options.kerberos:
	# try Kerberos first
	unauthidpurl = idpkrburl
    elif options.noprompt:
	# this should be disallowed above since noprompt implies kerberos
	fatal("programming error - neither kerberos nor prompt selected")
    else:
	unauthidpurl = idpurl

    def dounauthrequest(url):
	if options.verbose:
	    print "Making unauthorized request to IdP " + url
	elif showprogress:
	    sys.stdout.write('.')
	    sys.stdout.flush()
	idprequest = urllib2.Request(url=url)
	try:
	    notauthidphandle = opener.open(idprequest)
	except Exception, e:
	    efatal("Failure on (deliberately) unauthorized request to IdP %s" % url, e)

	if options.debug:
	    print "###### Begin IdP response to unauthorized request"
	    print notauthidphandle.info()
	    print "###### End IdP response to unauthorized request"

	if notauthidphandle.code != 401:
	    fatal("Did not get expected response code 401 from IdP %s, instead got code %d" % (url, notauthidphandle.code))
	return www_auth(notauthidphandle)

    wwwauthenticate = dounauthrequest(unauthidpurl)

    idphandle = None
    if options.kerberos:
	if 'negotiate' in wwwauthenticate:
	    netloc = urlparse.urlsplit(idpkrburl)[1]
	    hostname = re.sub(":.*", "", netloc)
	    service = "HTTP@" + hostname
	    if options.debug:
		print "###### Initializing kerberos context for " + service
	    __, krb_context = kerberos.authGSSClientInit(service)
	    try:
		kerberos.authGSSClientStep(krb_context, "")
	    except Exception, e:
		if options.noprompt:
		    efatal("Kerberos initialization failed", e)
		if options.verbose:
		    print "Kerberos initialization failed: %s" % e
		    print "Trying password"
		elif showprogress:
		    sys.stdout.write('.')
		    sys.stdout.flush()
	    else:
		negotiate_details = kerberos.authGSSClientResponse(krb_context)
		headers = {
		    'Content-Type': 'text/xml',
		    'Authorization': 'Negotiate ' + negotiate_details
		}

		# Redo it with kerberos, sending the AuthnRequest in a POST
		if options.verbose:
		    print "Making kerberized request to IdP " + idpkrburl
		elif showprogress:
		    sys.stdout.write('.')
		    sys.stdout.flush()

		authidpurl = idpkrburl
		idprequest = urllib2.Request(idpkrburl, headers=headers, data=idpbody)
		try:
		    idphandle = opener.open(idprequest)
		except Exception, e:
		    idphandle = None
		    efatal("Failure on response from IdP %s" % idpkrburl, e)

    if not options.noprompt and (idphandle == None):
	if idpurl != unauthidpurl:
	    wwwauthenticate = dounauthrequest(idpurl)

	if 'basic' not in wwwauthenticate:
	    fatal("IdP does not support password authentication")

	# ask for password
	promptstr = options.promptstr.replace("%username",username)
	if (promptstr.find('%realm') != -1):
	    basic = wwwauthenticate['basic']
	    if basic.find('realm=') == -1:
		fatal("IdP did not supply realm for password prompt")
	    realm = re.sub('.*realm="', '', basic)
	    realm = re.sub('".*', '', realm)
	    promptstr = promptstr.replace("%realm", realm)
	if showprogress:
	    print
	password = getpass.getpass(promptstr + ': ')
	base64string = base64.encodestring('%s:%s' % (username, password)).replace('\n', '')
	headers = {
	    'Content-Type': 'text/xml',
	    'Authorization': 'Basic ' + base64string
	}

	# POST the AuthnRequest to the IDP
	if options.verbose:
	    print "Making authorized request to IdP " + idpurl
	authidpurl = idpurl
	idprequest = urllib2.Request(idpurl, headers=headers, data=idpbody)
	try:
	    idphandle = opener.open(idprequest)
	except Exception, e:
	    efatal("Failure on response from IdP %s" % idpurl, e)

    if idphandle.code != 200:
	# in case unauthorized the second try
	if idphandle.code == 401:
	    fatal("authorization failed")
	fatal("unexpected http response code from IdP %s: %d" % (authidpurl, idphandle.code))

    idpetree = etree.XML(idphandle.read())

    if options.debug:
	print "###### Begin IdP response"
	print etree.tostring(idpetree, pretty_print=True)
	print "###### End IdP response"

    # pull out the AsssertionConsumerServiceURL
    try:
	assertionConsumerServiceURL = idpetree.xpath("/S:Envelope/S:Header/ecp:Response/@AssertionConsumerServiceURL", namespaces=namespaces)[0]
    except Exception, e:
	efatal("Unable to parse AssertionConsumerServiceURL from IdP response",  e)

    if options.debug:
	print "###### Begin AssertionConsumerServiceURL attribute"
	print assertionConsumerServiceURL
	print "###### End AssertionConsumerServiceURL attribute"

    if assertionConsumerServiceURL != responseConsumerURL:
	# IdP's response doesn't match SP's expectation
	if options.verbose:
	    print "Telling SP that IdP had a response error"
        soapfault = """
            <S:Envelope xmlns:S="http://schemas.xmlsoap.org/soap/envelope/">
               <S:Body>
                 <S:Fault>
                    <faultcode>S:Server</faultcode>
                    <faultstring>responseConsumerURL from SP and assertionConsumerServiceURL from IdP do not match</faultstring>
                 </S:Fault>
               </S:Body>
            </S:Envelope>
            """
        headers = { 'Content-Type' : 'application/vnd.paos+xml' }
	request = urllib2.Request(responseConsumerURL, headers=headers, data=soapfault)
        # POST the fault to the SP but ignore any failure
        try:
            handle = opener.open(request)
        except Exception, e:
            pass

	fatal("assertionConsumerServiceURL %s from IdP does not match responseConsumerURL %s from SP" % (assertionConsumerServiceURL, responseConsumerURL))

    if showprogress:
	print ' authorized'

    # replace the header of the idp response with the relay state sent by the
    #  Assertion Consumer (which is on the SP)
    acrequestbody = idpetree
    acrequestbody[0][0] = relayState
    acbody=etree.tostring(acrequestbody)

    if options.debug:
	print "###### Begin SP Assertion Consumer body"
        print etree.tostring(acrequestbody, pretty_print=True)
	print "###### End SP Assertion Consumer body"


    if options.verbose:
	print "Sending response to Assertion Consumer " + assertionConsumerServiceURL
    elif showprogress:
	sys.stdout.write('Fetching certificate ...')
	sys.stdout.flush()
    headers = { 'Content-Type' : 'application/vnd.paos+xml' }
    acrequest = urllib2.Request(assertionConsumerServiceURL, headers=headers, data=acbody)
    try:
	achandle = opener.open(acrequest)
    except Exception, e:
	efatal("Failure on response from assertion consumer %s" % assertionConsumerServiceURL, e)

    # Ignore the response body. We only want the cookie which the opener
    #   has already stored in the cookiejar.

    shibcookie = cookiejar.make_cookies(achandle, acrequest)[0]
    if options.debug:
	print "###### Begin shibboleth cookie"
        print [shibcookie]
	print "###### End shibboleth cookie"

    def random_string(length, outof=string.ascii_lowercase+string.digits):
        """Return a string of `length` characters chosen from `outof`.

        Each character is picked independently with random.SystemRandom,
        which is backed by os.urandom() rather than the seedable default
        generator, making it suitable for the CSRF token and pkcs12
        password built from it (http://stackoverflow.com/a/23728630/2213647).
        A non-positive `length` yields an empty string.
        """
        rng = random.SystemRandom()
        picks = [rng.choice(outof) for _ in range(length)]
        return ''.join(picks)

    # Add a 10-character random Cross Site Request Forgery prevention cookie.
    # It also has to be a form value in order to pass the CILogon CSRF check.
    csrfstr = random_string(10)
    headers = {
	'Content-Type' : 'application/x-www-form-urlencoded',
	'Cookie' : 'CSRF=' + csrfstr + '; ' +
	    shibcookie.name + '=' + shibcookie.value
    }

    # Choose a random password for encrypting pkcs12 cert/key over the link.
    # The ascii letters are for strength, the digits and special characters
    #  are just in case future rules enforce such things.
    # Could instead use a CSR but that limits certificates to 277 hours.
    p12password = random_string(16, string.ascii_letters) + \
	random_string(2, string.digits) + random_string(2, '!@#$%^&*()')

    certformvars = [
	('submit' , 'pkcs12'),
	('CSRF' , csrfstr),
	('p12password' , p12password),
	('p12lifetime' , int(options.hours))
    ]
    certformdata = urllib.urlencode(certformvars)

    if options.verbose:
	print "Requesting certificate from SP " + options.spurl
    elif showprogress:
	sys.stdout.write('.')
	sys.stdout.flush()
    spcertrequest = urllib2.Request(url=options.spurl,data=certformdata, 
    		headers=headers)
    try:
	spcerthandle = opener.open(spcertrequest)
    except Exception, e:
	efatal("cert request to SP %s failed" % options.spurl,e)

    pkcs12cert = spcerthandle.read()
    if options.debug:
	print "Read %d bytes of encrypted pkcs12 certificate" % len(pkcs12cert)

    if options.verbose:
	print "Converting PKCS12 certificate to PEM"
    elif showprogress:
	sys.stdout.write('.')
	sys.stdout.flush()

    # M2Crypto does not support pkcs12 so need to use pyOpenSSL
    try:
	p12 = crypto.load_pkcs12(pkcs12cert, p12password)
    except Exception, e:
	efatal("could not decode certificate from SP %s" % options.spurl,e)

    cert = p12.get_certificate()
    key = p12.get_privatekey()

    # convert back to M2Crypto objects
    certstr = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
    cert = X509.load_cert_string(certstr)
    keystr = crypto.dump_privatekey(crypto.FILETYPE_PEM, key)
    key = EVP.load_key_string(keystr).get_rsa()

    if options.debug:
	print "###### Begin certificate"
        sys.stdout.write(certstr)
	print "###### End certificate"
        # deliberately not printing key, to prevent somebody from
        #  accidentally storing it on disk if they redirect stdout
        #  when using the debug option.
    if showprogress:
	print ' fetched'

    proxyorcert = 'certificate'
    if options.proxy:
	proxyorcert = 'proxy'
	if options.verbose:
	    print "Generating proxy for storage"
	elif showprogress:
	    sys.stdout.write('Generating proxy ...')
	    sys.stdout.flush()
	try:
	    (proxystr, proxykeystr) = generate_proxycert(cert, key, options.proxyhours)
	    proxy = X509.load_cert_string(proxystr)
	except Exception, e:
	    efatal("failure generating proxy for storage", e)
	if showprogress:
	    print ' generated'

    if options.verbose or showprogress:
	print 'Storing ' + proxyorcert + ' in ' + outfile
    handle = os.fdopen(os.open(outfile, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0600), 'w')

    if options.proxy:
	handle.write(proxystr)
	handle.write(proxykeystr)
	handle.write(certstr)
	firstcert = proxy
    else:
	handle.write(certstr)
	handle.write(keystr)
	firstcert = cert
    handle.close()

    if options.verbose:
        print "subject  : " + x509name_to_str(firstcert.get_subject())
        print "issuer   : " + x509name_to_str(firstcert.get_issuer())

    if options.verbose or showprogress:
	validuntil = time.ctime(asn1time_local_secs(firstcert.get_not_after()))
	print 'Your ' + proxyorcert + ' is valid until: ' + validuntil

    ### MyProxy handling section
    if options.myproxyserver == None:
	sys.exit(0)

    if options.verbose:
	print "Generating proxy for MyProxy"
    elif showprogress:
	sys.stdout.write('Generating proxy for MyProxy ...')
	sys.stdout.flush()
    try:
	(myproxystr, myproxykeystr) = generate_proxycert(cert, key, options.hours)
    except Exception, e:
	efatal("failure generating proxy for MyProxy", e)
    if showprogress:
	print ' generated'

    if options.verbose:
	print "Storing proxy in MyProxy server " + options.myproxyserver
    elif showprogress:
	sys.stdout.write("Storing proxy in MyProxy ...")
	sys.stdout.flush()

    myproxyusername = replace_certsubject(myproxyusername, firstcert)

    sslsock = start_myproxy_command(outfile, '5', myproxyusername,
		    '', options.myproxyhours, options.myproxyretrievers)

    response, params = parse_myproxy_response(sslsock)
    if response:
	fatal('error from MyProxy on store request: ' + params['ERROR'])

    if showprogress:
	sys.stdout.write('.')
	sys.stdout.flush()

    if options.debug:
	print "###### Begin chain sending to MyProxy"
	sys.stdout.write(myproxystr + myproxykeystr + certstr)
	print "###### End chain sending to MyProxy"

    # these have to all be sent in one write or sometimes MyProxy 
    #  doesn't read all the pieces properly
    sslsock.send(myproxystr + myproxykeystr + certstr)

    response, params = parse_myproxy_response(sslsock)
    if response:
	fatal('error from MyProxy on store: ' + params['ERROR'])

    if showprogress:
	print ' stored'

if __name__ == "__main__":
    # Run main() only when this file is executed directly as a script,
    # not when it is imported as a module.
    main()
