Browse Source

Working Connection Handler

Fred Damstra [afs macbook] 3 years ago
parent
commit
b8c0c7de47

+ 197 - 0
base/aws_client_vpn/files/connection_authorization/connection_handler.py

@@ -0,0 +1,197 @@
+# A connection handler to check if somebody is already connected to the VPN, and if so, to disconnect them.
+#
+# References:
+#   https://docs.aws.amazon.com/vpn/latest/clientvpn-admin/connection-authorization.html
+#   https://aws.amazon.com/blogs/networking-and-content-delivery/enforcing-vpn-access-policies-with-aws-client-vpn-connection-handler/
+#
+# Example input event:
+#     {
+#         "connection-id": <connection ID>,
+#         "endpoint-id": <client VPN endpoint ID>,
+#         "common-name": <cert-common-name>,
+#         "username": <user identifier>,
+#         "platform": <OS platform>,
+#         "platform-version": <OS version>,
+#         "public-ip": <public IP address>,
+#         "client-openvpn-version": <client OpenVPN version>,
+#         "schema-version": "v1"
+#     }
+#
+# Example output:
+#     {
+#         "allow": boolean,
+#         "error-msg-on-failed-posture-compliance": "",
+#         "posture-compliance-statuses": [],
+#         "schema-version": "v1"
+#     }
+#
+# Boto3 stuff:
+#   describe_client_vpn_connections: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2.html?highlight=describeclientvpnendpoints#EC2.Client.describe_client_vpn_connections
+#              paginator: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2.html?highlight=describeclientvpnendpoints#EC2.Paginator.DescribeClientVpnConnections
+#   terminate_client_vpn_connections: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2.html?highlight=describeclientvpnendpoints#EC2.Client.terminate_client_vpn_connections
+
+
+import boto3
+import boto3.session
+import datetime
+import json
+import logging
+
+# Configuration
+DISCONNECT_EXISTING=False # The client automatically reconnects. Best we can do is not allow new connections.
+
+# Globals
+client = None
+session = None
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+
+def disconnect(username, endpoint):
+    ''' Returns True if the username is presently connected to endpoint '''
+    # In practice, this code runs indefinitely, because by the time it terminates one connection, another has been formed.
+    global client
+    logger.info(f'Disconnecting user "{username}" from endpoint "{endpoint}"')
+    all_disconnected = False
+
+    while not all_disconnected:
+        try:
+            response = client.terminate_client_vpn_connections(ClientVpnEndpointId=endpoint, Username=username)
+        except Exception as e:
+            logger.error(f'Exception while trying to disconnect. Exception was {str(e)}')
+            return False
+        if len(response['ConnectionStatuses']) == 0:
+            all_disconnected = True
+        else:
+            for c in response['ConnectionStatuses']:
+                status = c.get('CurrentStatus', {}).get('Code')
+                message = c.get('CurrentStatus', {}).get('Message')
+                if status != 'terminating' and status != 'terminated':
+                    logger.error(f'Failed to disconnect "{username}". Message is: {message}')
+                    return False
+    return True
+
+
+def is_connected(username, endpoint):
+    ''' Returns True if the username is presently connected to endpoint '''
+    # This isn't really necessary, I don't think. If you call disconnect, it'll disconnect all of them.
+    # But it's useful for intelligence gathering.
+    global client
+    logger.debug(f'Checking if user "{username}" is connected to endpoint "{endpoint}"')
+    paginator = client.get_paginator('describe_client_vpn_connections')
+    response_iterator = paginator.paginate(
+        ClientVpnEndpointId=endpoint,
+        Filters=[
+            {
+                'Name': 'username',
+                'Values': [
+                    username,
+                ]
+            },
+        ],
+        DryRun=False,
+    )
+    for r in response_iterator:
+        connections = r.get('Connections')
+        if len(connections) == 0:
+            logger.debug(f'User "{username}" is not connected.')
+            return False
+        else:
+            for c in connections:
+                logger.debug(f'Evaluating connection: {json.dumps(c, indent=2, default=str)}')
+                if c['Status']['Code'] != 'terminated' and c['Status']['Code'] != 'terminating':
+                    logger.debug(f'User "{username}" is connected.')
+                    logger.debug(f'Details: {json.dumps(connections, indent=2, default=str)}')
+                    return True
+    return False # User is no longer connected
+
+
+
+def lambda_handler(event, context):
+    global client, session
+    try:
+        if session:
+            client = session.client('ec2')
+        else:
+            client = boto3.client('ec2')
+    except Exception as e:
+        logger.error(f'Could not create client session. Error was: {str(e)}')
+        return {
+            "allow": False,
+            "error-msg-on-failed-posture-compliance": str(e),
+            "posture-compliance-statuses": [],
+            "schema-version": "v1"
+        }
+
+    username = event.get('username', None)
+    endpoint = event.get('endpoint-id', None)
+
+    if is_connected(username, endpoint):
+        if DISCONNECT_EXISTING:
+            if disconnect(username, endpoint):
+                logger.info(f'Disconnecting existing session for "{username}" and allowing the new connection.')
+                return {
+                    "allow": True,
+                    "error-msg-on-failed-posture-compliance": '',
+                    "posture-compliance-statuses": [],
+                    "schema-version": "v1"
+                }
+            else:
+                # Error during disconnect?
+                logger.error(f'Unable to disconnect user "{username}"')
+                return {
+                    "allow": False,
+                    "error-msg-on-failed-posture-compliance": 'Unable to disconnect',
+                    "posture-compliance-statuses": [],
+                    "schema-version": "v1"
+                }
+        else:
+            logger.info(f'User "{username}" is connected. Not allowing new connection.')
+            return {
+                "allow": False,
+                "error-msg-on-failed-posture-compliance": 'Your account is already connected. XDR VPN connections are limited to one connection at a time. If you have recently disconnected, please wait approximately 2 minutes and attempt your connection again.',
+                "posture-compliance-statuses": [],
+                "schema-version": "v1"
+            }
+    # User is not connected, allow the connection
+    return {
+        "allow": True,
+        "error-msg-on-failed-posture-compliance": '',
+        "posture-compliance-statuses": [],
+        "schema-version": "v1"
+    }
+
+
+def main():
+    ''' main() performs local testing. '''
+    # Set up logging to stdout
+    global logger
+    handler = logging.StreamHandler()
+    logger.addHandler(handler)
+
+    # Turn off debugging logs for AWS
+    for module in [ 'boto3', 'botocore', 'nose', 's3transfer', 'urllib3', 'urllib3.connectionpool' ]:
+            l = logging.getLogger(module)
+            l.setLevel(logging.INFO)
+
+    logger.setLevel(logging.DEBUG) # Debug logs for running locally. In the cloud runs as 'info'
+
+    # Set up the boto3 session
+    global session
+    session = boto3.session.Session(profile_name='mdr-test-c2-gov', region_name='us-gov-east-1')
+
+    test_event = {
+        "connection-id": 'testconnectionid',
+        "endpoint-id": 'cvpn-endpoint-0a4ccca3756e984c9',
+        "common-name": '<cert-common-name>',
+        "username": 'frederick.t.damstra',
+        "platform": '<OS platform>',
+        "platform-version": '<OS version>',
+        "public-ip": '<public IP address>',
+        "client-openvpn-version": '<client OpenVPN version>',
+        "schema-version": "v1"
+    }
+    result = lambda_handler(event=test_event, context={})
+    print(f'Result: {json.dumps(result, indent=2, default=str)}')
+
+if __name__ == "__main__":
+    main()

+ 114 - 0
base/aws_client_vpn/lambda.tf

@@ -0,0 +1,114 @@
+# Lambda function to refuse concurrent connections
+data "archive_file" "lambda_connection_authorization" {
+  type             = "zip"
+  source_file      = "${path.module}/files/connection_authorization/connection_handler.py"
+  # 0666 results in "more consistent behavior" according to https://registry.terraform.io/providers/hashicorp/archive/latest/docs/data-sources/archive_file
+  output_file_mode = "0666"
+  output_path      = "${path.module}/files/connection_authorization/connection_handle.zip"
+}
+
+resource "aws_iam_role" "lambda_connection_authorization" {
+  name = "awsclientvpn-connection-handler"
+  path = "/lambda/"
+
+  assume_role_policy = <<EOF
+{
+  "Version": "2012-10-17",
+  "Statement": [
+    {
+      "Action": "sts:AssumeRole",
+      "Principal": {
+        "Service": "lambda.amazonaws.com"
+      },
+      "Effect": "Allow",
+      "Sid": ""
+    }
+  ]
+}
+EOF
+}
+
+data "aws_iam_policy_document" "lambda_connection_authorization_policy_doc" {
+  statement {
+    sid       = ""
+    effect    = "Allow"
+    resources = ["*"]
+
+    actions = [
+      "ec2:DescribeClientVpnConnections",
+      "ec2:TerminateClientVpnConnections",
+      "logs:CreateLogStream",
+      "logs:CreateLogGroup",
+      "logs:PutLogEvents",
+    ]
+  }
+}
+
+resource "aws_iam_policy" "lambda_connection_authorization_policy" {
+  name        = "awsclientvpn-connection-handler"
+  path        = "/lambda/"
+  policy      = data.aws_iam_policy_document.lambda_connection_authorization_policy_doc.json
+}
+
+resource "aws_iam_role_policy_attachment" "lambda_connection_authorization_policy_attachment" {
+  role       = aws_iam_role.lambda_connection_authorization.name
+  policy_arn = aws_iam_policy.lambda_connection_authorization_policy.arn
+}
+
+resource "aws_lambda_function" "lambda_connection_authorization" {
+  function_name = "AWSClientVPN-ConnectionHandler"
+  description   = "Only allows one concurrent connection"
+  runtime       = "python3.9"
+  memory_size   = 128
+  publish       = true
+  timeout       = 30 # Cannot be changed (maybe can be reduced?)
+  filename      = data.archive_file.lambda_connection_authorization.output_path
+  role          = aws_iam_role.lambda_connection_authorization.arn
+  handler       = "connection_handler.lambda_handler"
+
+  source_code_hash = data.archive_file.lambda_connection_authorization.output_base64sha256
+
+  #environment {
+  #  variables = {
+  #    # TODO: Set logging level
+  #  }
+  #}
+  
+  tags = merge(var.standard_tags, var.tags)
+}
+
+
+#module "lambda_function" {
+#  source = "terraform-aws-modules/lambda/aws"
+#
+#  function_name = "AWSClientVPN-ConnectionHandler"
+#  description   = "Determines whether user is allowed to log in."
+#  handler       = "connection_handler.lambda_handler"
+#  runtime       = "python3.9"
+#  timeout       = 30 # Cannot be changes on a connection handler
+#  publish       = true
+#
+#  source_path = "${path.module}/files/connection_authorization/connection_handler.py"
+#
+#  attach_policy_json = true
+#  policy_json = <<EOF
+#{
+#     "Version": "2012-10-17",
+#     "Statement": [
+#       {
+#             "Effect": "Allow",
+#             "Action": [
+#                 "ec2:DescribeClientVpnConnections",
+#                 "ec2:TerminateClientVpnConnections"
+#             ],
+#             "Resource": "*"
+#         }
+#     ]
+#}
+#EOF
+## The following 3 permissions are autoatically added by the module:
+##                 "logs:CreateLogStream",
+##                 "logs:CreateLogGroup",
+##                 "logs:PutLogEvents",
+#  tags = merge(var.standard_tags, var.tags)
+#}

+ 9 - 0
base/aws_client_vpn/outputs.tf

@@ -18,3 +18,12 @@ output "vpn_id" {
 output "self_service_url" {
   value = "https://gov.self-service.clientvpn.amazonaws.com/endpoints/${ aws_ec2_client_vpn_endpoint.vpn.id }"
 }
+
+output "lambda_function_arn" {
+  #value = module.lambda_function.lambda_function_arn
+  value = aws_lambda_function.lambda_connection_authorization.arn
+}
+
+output "lambda_function_reminder" {
+  value = "You must configure the lambda connection handler in the AWS console. VPC->Client VPN Endpoints->Modify, under 'Client Connect Handler'"
+}