Automatic Sharing and Loading RDS Snapshots Using Lambda

May 13, 2020

This article was written by Henrik Lernmark, a certified AWS Solutions Architect with 20+ years of experience in the software industry.

https://www.linkedin.com/in/henriklernmark/

The code snippets below automate sharing RDS snapshots between AWS accounts and loading them into another database.

We have a scenario with a production account and a staging account, each with RDS databases deployed. RDS is set up to create an automatic snapshot once a day. To keep the data in the staging environment current for testing and validation purposes, we want to load the daily snapshot from the production account into the staging account.

To accomplish this I created two Lambdas. The first runs in the production account, copies the latest snapshot, and shares it with the staging account. The second runs in the staging account, renames the current database, loads the snapshot into a new database, and then deletes the old one.

The first Lambda is triggered by the CloudWatch event "RDS DB Snapshot Event" with the message "Automated snapshot created."
It runs every time an automatic backup of the database occurs and starts copying the latest automatically created snapshot.
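As a sketch, the CloudWatch Events rule for this trigger can match on the source and detail-type that the handler below checks for (you may want to narrow the pattern further to your specific instance):

```json
{
  "source": ["aws.rds"],
  "detail-type": ["RDS DB Snapshot Event"]
}
```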

Unfortunately, you can’t share a snapshot copy while it’s being created, and there’s no CloudWatch event for when a snapshot copy is done, so I chose to use a waiter function. Not ideal in a Lambda, but as long as it’s not a giant database, it’s OK. If you have a huge database, you probably shouldn’t be using this approach to keep your staging data current anyway.

The environment variable MAX_WAIT is used to calculate the waiter’s configuration and should be set to a lower value than the Lambda’s timeout.
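To make the relationship concrete, here is a minimal sketch of how MAX_WAIT maps to a waiter configuration with the 5-second delay used in the code below (the helper name is mine, not part of the Lambda):

```python
import os

DELAY_SECONDS = 5  # the poll interval the waiter uses between attempts

def waiter_config_from_env(env=os.environ):
    # MAX_WAIT is the total number of seconds we are willing to wait;
    # keep it below the Lambda's configured timeout.
    max_wait = int(env.get("MAX_WAIT", "60"))
    return {'Delay': DELAY_SECONDS, 'MaxAttempts': max_wait // DELAY_SECONDS}
```

For example, MAX_WAIT=60 yields 12 attempts of 5 seconds each.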

Once the copy is created, it’s shared with the account number given in the SHARED_ACCOUNT environment variable.

import boto3
import logging
import os
import datetime
from operator import itemgetter

logger = logging.getLogger(name=__name__)
env_level = os.environ.get("LOG_LEVEL")
log_level = logging.INFO if not env_level else env_level
logger.setLevel(log_level)

rds = boto3.client('rds')
waiter = rds.get_waiter('db_snapshot_available')

def copy_latest_snapshot(db_identifier):
  timestamp = '{:%Y%m%d-%H%M%S}'.format(datetime.datetime.now())
  snapshot_copy = f"snapshot-copy{timestamp}"
  logger.info(f"Find latest snapshot of: {db_identifier}")
  response = rds.describe_db_snapshots(DBInstanceIdentifier=db_identifier, SnapshotType='automated')
  # Sort by creation time, newest first
  sorted_keys = sorted(response['DBSnapshots'], key=itemgetter('SnapshotCreateTime'), reverse=True)
  snapshot_id = sorted_keys[0]['DBSnapshotIdentifier']
  logger.info(f"Wait for snapshot: {snapshot_id}")
  waiter.wait(
    DBInstanceIdentifier=db_identifier,
    DBSnapshotIdentifier=snapshot_id,
    WaiterConfig={'Delay': 5, 'MaxAttempts': 12}
  )
  logger.info(f"Create snapshot copy: {snapshot_copy} of {db_identifier}")
  rds.copy_db_snapshot(SourceDBSnapshotIdentifier=snapshot_id, TargetDBSnapshotIdentifier=snapshot_copy)
  return snapshot_copy

def share_snapshot(db_identifier, snapshot_id):
  try:
    max_wait = int(os.environ.get("MAX_WAIT"))
    max_att = int(max_wait / 5)
    logger.info(f"Wait for snapshot: {snapshot_id}")
    waiter.wait(
      DBInstanceIdentifier=db_identifier,
      DBSnapshotIdentifier=snapshot_id,
      WaiterConfig={'Delay': 5, 'MaxAttempts': max_att}
    )
    shared_account = os.environ.get("SHARED_ACCOUNT")
    logger.info(f"Share snapshot: {snapshot_id} of {db_identifier}")
    rds.modify_db_snapshot_attribute(
      DBSnapshotIdentifier=snapshot_id,
      AttributeName="restore",
      ValuesToAdd=[shared_account]
    )
  except Exception as e:
    logger.warning(e)

def copy_share():
  db_identifier = os.environ.get("RDS_DB")
  snapshot_copy = copy_latest_snapshot(db_identifier)
  share_snapshot(db_identifier, snapshot_copy)

def lambda_handler(event, context):
  if event and event['detail-type'] == 'RDS DB Snapshot Event':
    message = event['detail']['Message']
    src_db = event['detail']['SourceIdentifier']
    logger.info(f"{src_db}: {message}")
    if message.find('Automated snapshot created') >= 0:
      copy_share()
    else:
      logger.info(event)
  elif not event or event['detail-type'] == 'Scheduled Event':
    copy_share()
  else:
    logger.warning(f"Unhandled event: {event}")

This Lambda needs to run under an IAM role that includes the following policy:

{
  "Statement":
  [
    {
      "Action": [
        "rds:CopyDBSnapshot",
        "rds:ModifyDBSnapshot",
        "rds:DescribeDBSnapshots",
        "rds:ModifyDBSnapshotAttribute"
      ],
      "Resource": ["*"],
      "Effect": "Allow"
    }
  ]
}
For tighter security, the resource should be scoped to your RDS database and its snapshots rather than "*".
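A scoped version might look like the following sketch, where the region, account ID, and database name are placeholders you would replace with your own values:

```json
{
  "Statement":
  [
    {
      "Action": [
        "rds:CopyDBSnapshot",
        "rds:ModifyDBSnapshot",
        "rds:DescribeDBSnapshots",
        "rds:ModifyDBSnapshotAttribute"
      ],
      "Resource": [
        "arn:aws:rds:eu-west-1:111122223333:db:my-prod-db",
        "arn:aws:rds:eu-west-1:111122223333:snapshot:*"
      ],
      "Effect": "Allow"
    }
  ]
}
```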

The second lambda is triggered by four different CloudWatch events.

The first is a scheduled event set to run after the RDS backup window. It renames the current database instance. You could skip the renaming and delete the instance straight away, but I chose to keep the old database for now, to make it easy to revert if the creation of the new one goes wrong.

The other events are all “RDS DB Instance Event”.

The first is “Renamed instance from RDS_DB to RDS_DB-timestamp.” This event triggers finding the latest snapshot and loading it using the restore_db_instance_from_db_snapshot command. In this call you can set most things specific to your environment, except security groups! Apparently you could in older versions of boto3, according to the documentation that “latest” points to (1.9.189), but not in the version currently running in AWS Lambda (1.9.48).

This is why I also trigger on the “Finished updating DB parameter group” event. When it occurs, I first verify that the instance has the wrong security groups before calling modify_db_instance to set the correct ones. I also found that a wait statement is sometimes needed before the modification.

The final event used is “Restored from snapshot.” When this event occurs the old instance is deleted.

import boto3
import logging
import os
import datetime
from operator import itemgetter

logger = logging.getLogger(name=__name__)
env_level = os.environ.get("LOG_LEVEL")
log_level = logging.INFO if not env_level else env_level
logger.setLevel(log_level)

rds = boto3.client('rds')
waiter = rds.get_waiter('db_instance_available')

def get_latest_snapshot_arn():
  response = rds.describe_db_snapshots(IncludeShared=True, SnapshotType="shared")
  if not response['DBSnapshots']:
    return
  # Sort by creation time, newest first
  sorted_keys = sorted(response['DBSnapshots'], key=itemgetter('SnapshotCreateTime'), reverse=True)
  snapshot_arn = sorted_keys[0]['DBSnapshotArn']
  return snapshot_arn

def rename_current_db(db_identifier):
  new_identifier = ''
  response = rds.describe_db_instances(
    Filters=[{'Name': 'db-instance-id', 'Values': [db_identifier]}]
  )
  if response['DBInstances']:
    timestamp = '{:%Y%m%d-%H%M%S}'.format(datetime.datetime.now())
    new_identifier = f"{db_identifier}-{timestamp}"
    logger.info(f"Rename {db_identifier} to {new_identifier}")
    rds.modify_db_instance(
      DBInstanceIdentifier=db_identifier,
      NewDBInstanceIdentifier=new_identifier,
      ApplyImmediately=True
    )
  return new_identifier

def find_old_instance(db_identifier):
  old_db_id = ''
  response = rds.describe_db_instances()
  db_id = f"{db_identifier}-"
  for dbs in response['DBInstances']:
    db = dbs['DBInstanceIdentifier']
    if db.find(db_id) >= 0:
      old_db_id = db
      logger.info(f"Found: {old_db_id}")
      break
  return old_db_id

def load_latest_snapshot(db_identifier, snapshot_arn):
  logger.info(f"Create new: {db_identifier} from {snapshot_arn}")
  inst_class = os.environ.get("DB_INSTANCE_CLASS")
  subnet_group = os.environ.get("DB_SUBNET_GROUP")
  param_group = os.environ.get("DB_PARAMETER_GROUP")
  multi_az = os.environ.get("MULTI_AZ")
  is_multi_az = multi_az.lower() == "true"
  rds.restore_db_instance_from_db_snapshot(
    DBInstanceIdentifier=db_identifier,
    DBSnapshotIdentifier=snapshot_arn,
    DBInstanceClass=inst_class,
    DBSubnetGroupName=subnet_group,
    DBParameterGroupName=param_group,
    MultiAZ=is_multi_az,
    PubliclyAccessible=False,
    AutoMinorVersionUpgrade=True,
    CopyTagsToSnapshot=True,
  )

def modify_security_groups(db_identifier):
  sec_groups = os.environ.get("SECURITY_GROUPS").split(',')
  set_groups = False
  response = rds.describe_db_instances(
    Filters=[{'Name': 'db-instance-id', 'Values': [db_identifier]}]
  )
  if response['DBInstances']:
    db_inst = response['DBInstances'][0]
    state = db_inst['DBInstanceStatus']
    cur_sec_groups = db_inst['VpcSecurityGroups']
    for sec_group in cur_sec_groups:
      if sec_group['VpcSecurityGroupId'] not in sec_groups:
        set_groups = True
        break
    if set_groups:
      try:
        if state != 'available':
          logger.info(f"{db_identifier} in {state}, not available for modification, waiting")
          max_wait = int(os.environ.get("MAX_WAIT"))
          max_att = int(max_wait / 5)
          waiter.wait(
            DBInstanceIdentifier=db_identifier,
            WaiterConfig={'Delay': 5, 'MaxAttempts': max_att}
          )
        logger.info(f"Set security groups: {sec_groups}")
        rds.modify_db_instance(
          DBInstanceIdentifier=db_identifier,
          VpcSecurityGroupIds=sec_groups,
          ApplyImmediately=True
        )
      except Exception as e:
        logger.warning(e)
    else:
      logger.info("Already correct security groups")
  else:
    logger.warning(f"{db_identifier} not found!")

def delete_old_instance(db_identifier):
  logger.info(f"Deleting old db: {db_identifier}")
  rds.delete_db_instance(DBInstanceIdentifier=db_identifier, SkipFinalSnapshot=True)

def lambda_handler(event, context):
  db_identifier = os.environ.get("RDS_DB")
  snapshot_arn = get_latest_snapshot_arn()
  if not snapshot_arn:
    return
  if event and event['detail-type'] == 'RDS DB Instance Event':
    message = event['detail']['Message']
    src_db = event['detail']['SourceIdentifier']
    logger.info(f"{src_db}: {message}")
    if message.find('Renamed instance from') >= 0:
      load_latest_snapshot(db_identifier, snapshot_arn)
    elif message.find('Restored from snapshot') >= 0:
      old_instance = find_old_instance(db_identifier)
      if old_instance:
        delete_old_instance(old_instance)
      else:
        logger.info("Found no old instance to delete")
    elif message.find('Finished updating DB parameter group') >= 0:
      modify_security_groups(db_identifier)
    else:
      logger.info(event)
  elif not event or event['detail-type'] == 'Scheduled Event':
    #Can also be triggered by an empty test event
    old_instance = rename_current_db(db_identifier)
    if not old_instance:
      load_latest_snapshot(db_identifier, snapshot_arn)

This Lambda needs a role with the following policy:

{
  "Statement":
  [
    {
      "Action":
      [
        "rds:DescribeDBInstances",
        "rds:CreateDBInstance",
        "rds:DeleteDBInstance",
        "rds:DeleteDBInstanceAutomatedBackup",
        "rds:ModifyDBInstance",
        "rds:StartDBInstance",
        "rds:StopDBInstance",
        "rds:DescribeDBSnapshots",
        "rds:RestoreDBInstanceFromDBSnapshot"
      ],
      "Resource": ["*"],
      "Effect": "Allow"
    }
  ]
}

In the code I’m using string.find(substring) >= 0 since most of the matching is on substrings, but I also noticed that some event['detail']['Message'] values contain a trailing whitespace, which would break exact equality checks.
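For example, an exact comparison breaks on a trailing space while a substring check does not:

```python
# A message as it might arrive in the event detail, with a trailing space.
message = "Automated snapshot created "

# Exact equality fails because of the trailing whitespace.
exact_match = (message == "Automated snapshot created")             # False

# Substring matching with find() is unaffected by it.
substring_match = message.find("Automated snapshot created") >= 0   # True
```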

Link to the original article on Medium.com: 
https://medium.com/purplescout/automatic-sharing-and-loading-rds-snapshots-using-lambda-acbbfa5230ea
