{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Importing S3 Data into FinSpace\n",
    "\n",
    "This notebook shows how to use the FinSpace APIs to create a dataset and populate it with data from an S3 bucket outside of FinSpace.\n",
    "\n",
    "## Preparation\n",
    "Before running the cells below, create an S3 bucket, load the dataset (in CSV format) into that bucket, and then apply the bucket policy described below so that FinSpace can read the CSV file from S3.\n",
    "\n",
    "## DataSet\n",
    "ESG News Sentiment Dataset - Trial\n",
    "\n",
    "Provided By: Amenity Analytics \n",
    "\n",
    "https://aws.amazon.com/marketplace/pp/prodview-4doy3qrqm3y6g?ref_=srh_res_product_title\n",
    "\n",
    "Copy the data into a bucket whose bucket policy grants access to the FinSpace service account. That bucket must grant \n",
    "the s3:GetObject and s3:ListBucket actions to the service account ARN.\n",
    "\n",
    "FinSpace Service Account ARN (replace with your environment's service account):   \n",
    "    arn:aws:iam::**INFRASTRUCTURE_ACCOUNT_ID**:role/FinSpaceServiceRole\n",
    "\n",
    "## S3 Bucket Policy to be used\n",
    "\n",
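    "- the policy grants cross-account read access to the FinSpace service role; it does not make the bucket public\n",
    "- replace INFRASTRUCTURE_ACCOUNT_ID with the AWS account ID of your FinSpace environment's infrastructure account\n",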
    "- replace S3_BUCKET with your s3 bucket\n",
    "\n",
    "```\n",
    "{\n",
    "    \"Version\": \"2012-10-17\",\n",
    "    \"Id\": \"CrossAccountAccess\",\n",
    "    \"Statement\": [\n",
    "        {\n",
    "            \"Effect\": \"Allow\",\n",
    "            \"Principal\": {\n",
    "                \"AWS\": [\n",
    "                    \"arn:aws:iam::INFRASTRUCTURE_ACCOUNT_ID:role/FinSpaceServiceRole\"\n",
    "                ]\n",
    "            },\n",
    "            \"Action\": \"s3:GetObject\",\n",
    "            \"Resource\": \"arn:aws:s3:::S3_BUCKET/*\"\n",
    "        },\n",
    "        {\n",
    "            \"Effect\": \"Allow\",\n",
    "            \"Principal\": {\n",
    "                \"AWS\": [\n",
    "                    \"arn:aws:iam::INFRASTRUCTURE_ACCOUNT_ID:role/FinSpaceServiceRole\"\n",
    "                ]\n",
    "            },\n",
    "            \"Action\": \"s3:ListBucket\",\n",
    "            \"Resource\": \"arn:aws:s3:::S3_BUCKET\"\n",
    "        }\n",
    "    ]\n",
    "}\n",
    " ```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Connecting to cluster - fin-cluster-3d77[ar6oql9k]\n",
      "cleared existing credential location\n",
      "Persisted krb5.conf secret to /etc/krb5.conf\n",
      "re-establishing connection...\n",
      "Persisted keytab secret to /home/sagemaker-user/livy.keytab\n",
      "Authenticated to Spark cluster\n",
      "Persisted sparkmagic config to /home/sagemaker-user/.sparkmagic/config.json\n",
      "Started Spark cluster with clusterId: ar6oql9k\n",
      "finished reloading all magics & configurations\n",
      "Persisted finspace cluster connection info to /home/sagemaker-user/.sparkmagic/finspace_connection_info.json\n"
     ]
    }
   ],
   "source": [
    "%local\n",
    "from aws.finspace.cluster import FinSpaceClusterManager\n",
    "\n",
    "# Connect only once per kernel; re-running this cell reuses the existing cluster manager.\n",
    "if 'finspace_clusters' in globals():\n",
    "    print(f'connected to cluster: {finspace_clusters.get_connected_cluster_id()}')\n",
    "else:\n",
    "    finspace_clusters = FinSpaceClusterManager()\n",
    "    finspace_clusters.auto_connect()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import Python Utility Classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting Spark application\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table>\n",
       "<tr><th>ID</th><th>YARN Application ID</th><th>Kind</th><th>State</th><th>Spark UI</th><th>Driver log</th><th>Current session?</th></tr><tr><td>2</td><td>application_1628754622409_0003</td><td>pyspark</td><td>idle</td><td><a target=\"_blank\" href=\"http://ip-192-168-39-223.ec2.internal:20888/proxy/application_1628754622409_0003/\">Link</a></td><td><a target=\"_blank\" href=\"http://ip-192-168-46-239.ec2.internal:8042/node/containerlogs/container_1628754622409_0003_01_000001/livy\">Link</a></td><td>✔</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SparkSession available as 'spark'.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# %load ../Utilities/finspace.py\n",
    "import datetime\n",
    "import time\n",
    "import boto3\n",
    "import os\n",
    "import pandas as pd\n",
    "import urllib\n",
    "\n",
    "from urllib.parse import urlparse\n",
    "from botocore.config import Config\n",
    "from boto3.session import Session\n",
    "\n",
    "\n",
    "# Base FinSpace class\n",
    "class FinSpace:\n",
    "\n",
    "    def __init__(\n",
    "            self,\n",
    "            config=Config(retries={'max_attempts': 3, 'mode': 'standard'}),\n",
    "            boto_session: Session = None,\n",
    "            dev_overrides: dict = None,\n",
    "            service_name = 'finspace-data'):\n",
    "        \"\"\"\n",
    "        Build a FinSpace helper around a boto3 client for `service_name`.\n",
    "\n",
    "        With no arguments, the region comes from the EC2 instance metadata service (see get_region_name)\n",
    "        and the client uses the default endpoint. To target another endpoint, pass dev_overrides, e.g.\n",
    "           `finspace = FinSpace(dev_overrides={'hfs_endpoint': 'https://<api-id>.execute-api.us-east-1.amazonaws.com/alpha',\n",
    "                                               'region_name': 'us-east-1'})`\n",
    "\n",
    "        :param config: botocore client config; the default allows 3 attempts in 'standard' retry mode\n",
    "        :param boto_session: boto3 session to build the client from; when dev_overrides is None its region is used\n",
    "        :param dev_overrides: optional dict with 'hfs_endpoint' and/or 'region_name'; when given, the region is\n",
    "                              not looked up and stays None unless 'region_name' is supplied\n",
    "        :param service_name: boto3 service name of the client to create\n",
    "        \"\"\"\n",
    "        self.hfs_endpoint = None\n",
    "        self.region_name = None\n",
    "\n",
    "        if dev_overrides is None:\n",
    "            # Region precedence: the supplied session first, otherwise ask the instance metadata service.\n",
    "            self.region_name = boto_session.region_name if boto_session is not None else self.get_region_name()\n",
    "        else:\n",
    "            self.hfs_endpoint = dev_overrides.get('hfs_endpoint', self.hfs_endpoint)\n",
    "            self.region_name = dev_overrides.get('region_name', self.region_name)\n",
    "\n",
    "        self.config = config\n",
    "\n",
    "        if boto_session is None:\n",
    "            boto_session = boto3.session.Session(region_name=self.region_name)\n",
    "        self._boto3_session = boto_session\n",
    "\n",
    "        print(f\"service_name: {service_name}\")\n",
    "        print(f\"endpoint: {self.hfs_endpoint}\")\n",
    "        print(f\"region_name: {self.region_name}\")\n",
    "\n",
    "        self.client = self._boto3_session.client(service_name, endpoint_url=self.hfs_endpoint, config=self.config)\n",
    "\n",
    "    @staticmethod\n",
    "    def get_region_name(timeout: float = 5.0):\n",
    "        \"\"\"\n",
    "        Look up the AWS region of the current host from the EC2 instance metadata service.\n",
    "\n",
    "        :param timeout: seconds to wait for the metadata service before giving up\n",
    "        :type: float\n",
    "\n",
    "        :return: the region name returned by the metadata service, e.g. 'us-east-1'\n",
    "        :raises OSError: (urllib.error.URLError or a socket timeout) if the metadata service cannot be reached\n",
    "        \"\"\"\n",
    "        # `import urllib` alone does not load the `urllib.request` submodule, so import it explicitly.\n",
    "        import urllib.request\n",
    "\n",
    "        req = urllib.request.Request(\"http://169.254.169.254/latest/meta-data/placement/region\")\n",
    "        # Without a timeout, urlopen can block indefinitely when the link-local metadata address is unreachable.\n",
    "        with urllib.request.urlopen(req, timeout=timeout) as response:\n",
    "            return response.read().decode(\"utf-8\")\n",
    "\n",
    "    # --------------------------------------\n",
    "    # Utility Functions\n",
    "    # --------------------------------------\n",
    "    @staticmethod\n",
    "    def get_list(all_list: dict, name: str):\n",
    "        \"\"\"\n",
    "        Return the items stored under key `name` of an API response dict, as a new list.\n",
    "        Removes repetitive code in functions that call boto APIs and then extract the returned items.\n",
    "\n",
    "        :param all_list: API response to search\n",
    "        :type: dict\n",
    "\n",
    "        :param name: key to look up in all_list\n",
    "        :type: str\n",
    "\n",
    "        :return: list of the items under `name`, or an empty list when the key is absent\n",
    "        \"\"\"\n",
    "        if name not in all_list:\n",
    "            return []\n",
    "\n",
    "        # copy into a fresh list so callers never mutate the response object\n",
    "        return list(all_list[name])\n",
    "\n",
    "    # --------------------------------------\n",
    "    # Classification Functions\n",
    "    # --------------------------------------\n",
    "\n",
    "    def list_classifications(self):\n",
    "        \"\"\"\n",
    "        Fetch all classifications, requesting them sorted by name (single call, no pagination).\n",
    "\n",
    "        :return: list of classification records, empty if the response has none\n",
    "        \"\"\"\n",
    "        response = self.client.list_classifications(sort='NAME')\n",
    "        classifications = self.get_list(response, 'classifications')\n",
    "        return classifications\n",
    "\n",
    "    def classification_names(self):\n",
    "        \"\"\"\n",
    "        List only the names of all classifications.\n",
    "\n",
    "        :return: list of classification names\n",
    "        \"\"\"\n",
    "        return [entry['name'] for entry in self.list_classifications()]\n",
    "\n",
    "    def classification(self, name: str):\n",
    "        \"\"\"\n",
    "        Case-insensitive name lookup of a classification.\n",
    "\n",
    "        :param name: name of the classification to find\n",
    "        :type: str\n",
    "\n",
    "        :return: the first matching classification record, or None when nothing matches\n",
    "        \"\"\"\n",
    "        for candidate in self.list_classifications():\n",
    "            if candidate['name'].lower() == name.lower():\n",
    "                return candidate\n",
    "        return None\n",
    "\n",
    "    def describe_classification(self, classification_id: str):\n",
    "        \"\"\"\n",
    "        Describe a classification through the describe-taxonomy API and keep only the taxonomy section.\n",
    "\n",
    "        :param classification_id: the GUID of the classification to describe\n",
    "        :type: str\n",
    "\n",
    "        :return: the 'taxonomy' part of the response, or None when it is absent\n",
    "        \"\"\"\n",
    "        details = self.client.describe_taxonomy(taxonomyId=classification_id)\n",
    "        return details['taxonomy'] if 'taxonomy' in details else None\n",
    "\n",
    "    def create_classification(self, classification_definition):\n",
    "        \"\"\"\n",
    "        Create a classification (taxonomy) from the given definition.\n",
    "\n",
    "        :param classification_definition: definition passed to the create-taxonomy API as taxonomyDefinition\n",
    "\n",
    "        :return: id of the created taxonomy\n",
    "        \"\"\"\n",
    "        return self.client.create_taxonomy(taxonomyDefinition=classification_definition)[\"taxonomyId\"]\n",
    "\n",
    "    def delete_classification(self, classification_id):\n",
    "        \"\"\"\n",
    "        Delete the classification (taxonomy) with the given id.\n",
    "\n",
    "        :param classification_id: GUID of the classification to delete\n",
    "\n",
    "        :return: True when the call returns HTTP 200, otherwise the raw API response\n",
    "        \"\"\"\n",
    "        response = self.client.delete_taxonomy(taxonomyId=classification_id)\n",
    "        succeeded = response['ResponseMetadata']['HTTPStatusCode'] == 200\n",
    "        return True if succeeded else response\n",
    "\n",
    "    # --------------------------------------\n",
    "    # Attribute Set Functions\n",
    "    # --------------------------------------\n",
    "\n",
    "    def list_attribute_sets(self):\n",
    "        \"\"\"\n",
    "        Collect every dataset type (attribute set), following nextToken pagination.\n",
    "\n",
    "        :return: list of dataset type summaries\n",
    "        \"\"\"\n",
    "        page = self.client.list_dataset_types()\n",
    "        summaries = page['datasetTypeSummaries']\n",
    "        while 'nextToken' in page:\n",
    "            page = self.client.list_dataset_types(nextToken=page['nextToken'])\n",
    "            summaries.extend(page['datasetTypeSummaries'])\n",
    "        return summaries\n",
    "\n",
    "    def attribute_set_names(self):\n",
    "        \"\"\"\n",
    "        List only the names of all dataset types (attribute sets).\n",
    "\n",
    "        :return: list of dataset type names\n",
    "        \"\"\"\n",
    "        return [dataset_type['name'] for dataset_type in self.list_dataset_types()]\n",
    "\n",
    "    def attribute_set(self, name: str):\n",
    "        \"\"\"\n",
    "        Case-insensitive name lookup of a dataset type (attribute set).\n",
    "\n",
    "        :param name: name of the dataset type to find\n",
    "        :type: str\n",
    "\n",
    "        :return: the first matching dataset type record, or None when nothing matches\n",
    "        \"\"\"\n",
    "        for dataset_type in self.list_dataset_types():\n",
    "            if dataset_type['name'].lower() == name.lower():\n",
    "                return dataset_type\n",
    "        return None\n",
    "\n",
    "    def describe_attribute_set(self, attribute_set_id: str):\n",
    "        \"\"\"\n",
    "        Describe a dataset type (attribute set) and keep only its 'datasetType' section.\n",
    "\n",
    "        :param attribute_set_id: the GUID of the dataset type to describe\n",
    "        :type: str\n",
    "\n",
    "        :return: the 'datasetType' part of the response, or None when it is absent\n",
    "        \"\"\"\n",
    "        details = self.client.describe_dataset_type(datasetTypeId=attribute_set_id)\n",
    "        return details['datasetType'] if 'datasetType' in details else None\n",
    "\n",
    "    def create_attribute_set(self, attribute_set_def):\n",
    "        \"\"\"\n",
    "        Create a dataset type (attribute set) from the given definition.\n",
    "\n",
    "        :param attribute_set_def: definition passed to the API as datasetTypeDefinition\n",
    "\n",
    "        :return: id of the created dataset type\n",
    "        \"\"\"\n",
    "        return self.client.create_dataset_type(datasetTypeDefinition=attribute_set_def)[\"datasetTypeId\"]\n",
    "\n",
    "    def delete_attribute_set(self, attribute_set_id: str):\n",
    "        \"\"\"\n",
    "        Delete the attribute set with the given id.\n",
    "\n",
    "        :param attribute_set_id: GUID of the attribute set to delete\n",
    "        :type: str\n",
    "\n",
    "        :return: True when the call returns HTTP 200, otherwise the raw API response\n",
    "        \"\"\"\n",
    "        # NOTE(review): the other attribute-set helpers call *_dataset_type APIs with datasetTypeId;\n",
    "        # confirm that delete_attribute_set(attributeSetId=...) exists in this service model.\n",
    "        response = self.client.delete_attribute_set(attributeSetId=attribute_set_id)\n",
    "        succeeded = response['ResponseMetadata']['HTTPStatusCode'] == 200\n",
    "        return True if succeeded else response\n",
    "\n",
    "    def associate_attribute_set(self, att_name: str, att_values: list, dataset_id: str):\n",
    "        \"\"\"\n",
    "        Associate the named attribute set (dataset type) with a dataset and set its values.\n",
    "\n",
    "        Any existing association between the two is removed first.\n",
    "\n",
    "        :param att_name: name of the attribute set (matched case-insensitively)\n",
    "        :type: str\n",
    "\n",
    "        :param att_values: values passed to update_dataset_type_context\n",
    "        :type: list\n",
    "\n",
    "        :param dataset_id: GUID of the dataset\n",
    "        :type: str\n",
    "\n",
    "        :return: response of the update_dataset_type_context call\n",
    "        :raises ValueError: if no attribute set named att_name exists, or the dataset details are missing\n",
    "        \"\"\"\n",
    "        # get the attribute set by name, will need its id\n",
    "        att_set = self.attribute_set(att_name)\n",
    "        if att_set is None:\n",
    "            raise ValueError(f\"No attribute set found with name: {att_name}\")\n",
    "\n",
    "        # get the dataset's information, will need the arn\n",
    "        dataset = self.describe_dataset_details(dataset_id=dataset_id)\n",
    "        if dataset is None:\n",
    "            raise ValueError(f\"No dataset details returned for dataset id: {dataset_id}\")\n",
    "\n",
    "        # disassociate any existing relationship; presumably this fails when there is none\n",
    "        try:\n",
    "            self.client.dissociate_dataset_from_dataset_type(datasetArn=dataset['arn'],\n",
    "                                                             datasetTypeId=att_set['id'])\n",
    "        except Exception:\n",
    "            # Exception, not a bare except, so KeyboardInterrupt/SystemExit still propagate.\n",
    "            print(\"Nothing to disassociate\")\n",
    "\n",
    "        self.client.associate_dataset_with_dataset_type(datasetArn=dataset['arn'], datasetTypeId=att_set['id'])\n",
    "\n",
    "        return self.client.update_dataset_type_context(datasetArn=dataset['arn'], datasetTypeId=att_set['id'],\n",
    "                                                       values=att_values)\n",
    "\n",
    "    # --------------------------------------\n",
    "    # Permission Group Functions\n",
    "    # --------------------------------------\n",
    "\n",
    "    def list_permission_groups(self, max_results: int):\n",
    "        \"\"\"\n",
    "        List permission groups with a single request (no pagination).\n",
    "\n",
    "        :param max_results: maximum number of groups to request\n",
    "        :type: int\n",
    "\n",
    "        :return: list of permission group records\n",
    "        \"\"\"\n",
    "        # NOTE(review): other list calls in this class pass `maxResults`; confirm that `MaxResults`\n",
    "        # is the parameter name this service model expects.\n",
    "        response = self.client.list_permission_groups(MaxResults=max_results)\n",
    "        return self.get_list(response, 'permissionGroups')\n",
    "\n",
    "    def permission_group(self, name):\n",
    "        \"\"\"\n",
    "        Case-insensitive name lookup among the first 100 permission groups.\n",
    "\n",
    "        :param name: permission group name\n",
    "\n",
    "        :return: the first matching permission group record, or None when nothing matches\n",
    "        \"\"\"\n",
    "        for group in self.list_permission_groups(max_results = 100):\n",
    "            if group['name'].lower() == name.lower():\n",
    "                return group\n",
    "        return None\n",
    "\n",
    "    def describe_permission_group(self, permission_group_id: str):\n",
    "        \"\"\"\n",
    "        Describe a permission group and keep only its 'permissionGroup' section.\n",
    "\n",
    "        :param permission_group_id: id of the permission group\n",
    "        :type: str\n",
    "\n",
    "        :return: the 'permissionGroup' part of the response, or None when it is absent\n",
    "        \"\"\"\n",
    "        details = self.client.describe_permission_group(permissionGroupId=permission_group_id)\n",
    "        return details['permissionGroup'] if 'permissionGroup' in details else None\n",
    "\n",
    "    # --------------------------------------\n",
    "    # Dataset Functions\n",
    "    # --------------------------------------\n",
    "\n",
    "    def describe_dataset_details(self, dataset_id: str):\n",
    "        \"\"\"\n",
    "        Describe a dataset and keep only its 'dataset' section.\n",
    "\n",
    "        :param dataset_id: the GUID of the dataset to describe\n",
    "        :type: str\n",
    "\n",
    "        :return: the 'dataset' part of the response, or None when it is absent\n",
    "        \"\"\"\n",
    "        details = self.client.describe_dataset_details(datasetId=dataset_id)\n",
    "        return details[\"dataset\"] if 'dataset' in details else None\n",
    "\n",
    "    def create_dataset(self, name: str, description: str, permission_group_id: str, dataset_permissions: list, kind: str,\n",
    "                       owner_info, schema):\n",
    "        \"\"\"\n",
    "        Create a dataset\n",
    "\n",
    "        Warning, dataset names are not unique, be sure to check for the same name dataset before creating a new one\n",
    "\n",
    "        :param name: Name of the dataset\n",
    "        :type: str\n",
    "\n",
    "        :param description: Description of the dataset; newlines are replaced with spaces before sending\n",
    "        :type: str\n",
    "\n",
    "        :param permission_group_id: permission group for the dataset\n",
    "        :type: str\n",
    "\n",
    "        :param dataset_permissions: names of the permissions to grant the group on the dataset (may be empty or None)\n",
    "        :type: list\n",
    "\n",
    "        :param kind: Kind of dataset, choices: TABULAR\n",
    "        :type: str\n",
    "\n",
    "        :param owner_info: owner information for the dataset\n",
    "\n",
    "        :param schema: Schema of the dataset\n",
    "\n",
    "        :return: the dataset_id of the created dataset\n",
    "        \"\"\"\n",
    "        # the API expects each permission wrapped as {\"permission\": <name>}\n",
    "        request_dataset_permissions = [{\"permission\": permission_name} for permission_name in dataset_permissions or []]\n",
    "\n",
    "        response = self.client.create_dataset(name=name,\n",
    "                                              permissionGroupId=permission_group_id,\n",
    "                                              datasetPermissions=request_dataset_permissions,\n",
    "                                              kind=kind,\n",
    "                                              description=description.replace('\\n', ' '),\n",
    "                                              ownerInfo=owner_info,\n",
    "                                              schema=schema)\n",
    "\n",
    "        return response[\"datasetId\"]\n",
    "\n",
    "    def ingest_from_s3(self,\n",
    "                       s3_location: str,\n",
    "                       dataset_id: str,\n",
    "                       change_type: str,\n",
    "                       wait_for_completion: bool = True,\n",
    "                       format_type: str = \"CSV\",\n",
    "                       format_params: dict = None):\n",
    "        \"\"\"\n",
    "        Creates a changeset and ingests the data given in the S3 location into the changeset\n",
    "\n",
    "        :param s3_location: the source location of the data for the changeset, will be copied into the changeset\n",
    "        :type: str\n",
    "\n",
    "        :param dataset_id: the identifier of the containing dataset for the changeset to be created for this data\n",
    "        :type: str\n",
    "\n",
    "        :param change_type: What is the kind of changetype?  \"APPEND\", \"REPLACE\" are the choices\n",
    "        :type: str\n",
    "\n",
    "        :param wait_for_completion: should the function wait for the ingestion to complete?\n",
    "        :type: bool\n",
    "\n",
    "        :param format_type: format type, CSV, PARQUET, XML, JSON (sent upper-cased)\n",
    "        :type: str\n",
    "\n",
    "        :param format_params: dictionary of format parameters;\n",
    "                              defaults to {'separator': ',', 'withHeader': 'true'} when None\n",
    "        :type: dict\n",
    "\n",
    "        :return: the id of the changeset created\n",
    "        \"\"\"\n",
    "        if format_params is None:\n",
    "            # build the default per call instead of sharing one mutable dict across calls\n",
    "            format_params = {'separator': ',', 'withHeader': 'true'}\n",
    "\n",
    "        create_changeset_response = self.client.create_changeset(\n",
    "            datasetId=dataset_id,\n",
    "            changeType=change_type,\n",
    "            sourceType='S3',\n",
    "            sourceParams={'s3SourcePath': s3_location},\n",
    "            formatType=format_type.upper(),\n",
    "            formatParams=format_params\n",
    "        )\n",
    "\n",
    "        changeset_id = create_changeset_response['changeset']['id']\n",
    "\n",
    "        if wait_for_completion:\n",
    "            self.wait_for_ingestion(dataset_id, changeset_id)\n",
    "        return changeset_id\n",
    "\n",
    "    def describe_changeset(self, dataset_id: str, changeset_id: str):\n",
    "        \"\"\"\n",
    "        Describe the given changeset of the given dataset.\n",
    "\n",
    "        :param dataset_id: identifier of the dataset\n",
    "        :type: str\n",
    "\n",
    "        :param changeset_id: identifier of the changeset\n",
    "        :type: str\n",
    "\n",
    "        :return: the 'changeset' section of the response\n",
    "        \"\"\"\n",
    "        return self.client.describe_changeset(datasetId=dataset_id, id=changeset_id)['changeset']\n",
    "\n",
    "    def create_as_of_view(self, dataset_id: str, as_of_date: datetime.datetime, destination_type: str,\n",
    "                          partition_columns: list = None, sort_columns: list = None, destination_properties: dict = None,\n",
    "                          wait_for_completion: bool = True):\n",
    "        \"\"\"\n",
    "        Creates an 'as of' static view up to and including the requested 'as of' date provided.\n",
    "\n",
    "        :param dataset_id: identifier of the dataset\n",
    "        :type: str\n",
    "\n",
    "        :param as_of_date: as of date, will include changesets up to this date/time in the view\n",
    "        :type: datetime.datetime\n",
    "\n",
    "        :param destination_type: destination type\n",
    "        :type: str\n",
    "\n",
    "        :param partition_columns: columns to partition the data by for the created view (None means [])\n",
    "        :type: list\n",
    "\n",
    "        :param sort_columns: columns to sort the view by (None means [])\n",
    "        :type: list\n",
    "\n",
    "        :param destination_properties: destination properties (None means {})\n",
    "        :type: dict\n",
    "\n",
    "        :param wait_for_completion: should the function wait for the system to create the view?\n",
    "        :type: bool\n",
    "\n",
    "        :return str: GUID of the created view if successful\n",
    "        \"\"\"\n",
    "        # fresh containers per call; mutable defaults would be shared across calls\n",
    "        partition_columns = [] if partition_columns is None else partition_columns\n",
    "        sort_columns = [] if sort_columns is None else sort_columns\n",
    "        destination_properties = {} if destination_properties is None else destination_properties\n",
    "\n",
    "        create_materialized_view_resp = self.client.create_materialized_snapshot(\n",
    "            datasetId=dataset_id,\n",
    "            asOfTimestamp=as_of_date,\n",
    "            destinationType=destination_type,\n",
    "            partitionColumns=partition_columns,\n",
    "            sortColumns=sort_columns,\n",
    "            autoUpdate=False,\n",
    "            destinationProperties=destination_properties\n",
    "        )\n",
    "        view_id = create_materialized_view_resp['id']\n",
    "        if wait_for_completion:\n",
    "            self.wait_for_view(dataset_id=dataset_id, view_id=view_id)\n",
    "        return view_id\n",
    "\n",
    "    def create_auto_update_view(self, dataset_id: str, destination_type: str,\n",
    "                                partition_columns=None, sort_columns=None, destination_properties=None,\n",
    "                                wait_for_completion=True):\n",
    "        \"\"\"\n",
    "        Creates an auto-updating view of the given dataset\n",
    "\n",
    "        :param dataset_id: identifier of the dataset\n",
    "        :type: str\n",
    "\n",
    "        :param destination_type: destination type\n",
    "        :type: str\n",
    "\n",
    "        :param partition_columns: columns to partition the data by for the created view (None means [])\n",
    "        :type: list\n",
    "\n",
    "        :param sort_columns: columns to sort the view by (None means [])\n",
    "        :type: list\n",
    "\n",
    "        :param destination_properties: destination properties (None means {})\n",
    "        :type: dict\n",
    "\n",
    "        :param wait_for_completion: should the function wait for the system to create the view?\n",
    "        :type: bool\n",
    "\n",
    "        :return str: GUID of the created view if successful\n",
    "        \"\"\"\n",
    "        # fresh containers per call; mutable defaults would be shared across calls\n",
    "        partition_columns = [] if partition_columns is None else partition_columns\n",
    "        sort_columns = [] if sort_columns is None else sort_columns\n",
    "        destination_properties = {} if destination_properties is None else destination_properties\n",
    "\n",
    "        create_materialized_view_resp = self.client.create_materialized_snapshot(\n",
    "            datasetId=dataset_id,\n",
    "            destinationType=destination_type,\n",
    "            partitionColumns=partition_columns,\n",
    "            sortColumns=sort_columns,\n",
    "            autoUpdate=True,\n",
    "            destinationProperties=destination_properties\n",
    "        )\n",
    "        view_id = create_materialized_view_resp['id']\n",
    "        if wait_for_completion:\n",
    "            self.wait_for_view(dataset_id=dataset_id, view_id=view_id)\n",
    "        return view_id\n",
    "\n",
    "    def wait_for_ingestion(self, dataset_id: str, changeset_id: str, sleep_sec=10, timeout_sec: float = None):\n",
    "        \"\"\"\n",
    "        Poll the changeset until it reaches SUCCESS, or raise if it ends in any other terminal status.\n",
    "\n",
    "        :param dataset_id: GUID of the dataset\n",
    "        :type: str\n",
    "\n",
    "        :param changeset_id: GUID of the changeset\n",
    "        :type: str\n",
    "\n",
    "        :param sleep_sec: seconds to wait between checks\n",
    "        :type: int\n",
    "\n",
    "        :param timeout_sec: optional maximum number of seconds to wait; None (default) waits indefinitely\n",
    "        :type: float\n",
    "\n",
    "        :raises Exception: if the status is anything other than SUCCESS, PENDING or RUNNING\n",
    "        :raises TimeoutError: if timeout_sec elapses while the changeset is still PENDING or RUNNING\n",
    "        \"\"\"\n",
    "        deadline = None if timeout_sec is None else time.monotonic() + timeout_sec\n",
    "        while True:\n",
    "            status = self.describe_changeset(dataset_id=dataset_id, changeset_id=changeset_id)['status']\n",
    "            if status == 'SUCCESS':\n",
    "                print(\"Changeset complete\")\n",
    "                return\n",
    "            if status not in ('PENDING', 'RUNNING'):\n",
    "                raise Exception(f\"Bad changeset status: {status}, failing now.\")\n",
    "            if deadline is not None and time.monotonic() >= deadline:\n",
    "                raise TimeoutError(f\"Changeset {changeset_id} still {status} after {timeout_sec} sec\")\n",
    "            # report the real status: it may be RUNNING rather than PENDING\n",
    "            print(f\"Changeset status is still {status}, waiting {sleep_sec} sec ...\")\n",
    "            time.sleep(sleep_sec)\n",
    "\n",
    "    def wait_for_view(self, dataset_id: str, view_id: str, sleep_sec=10, timeout_sec: float = None):\n",
    "        \"\"\"\n",
    "        Poll the view until it reaches SUCCESS, or raise if it ends in any other terminal status.\n",
    "\n",
    "        :param dataset_id: GUID of the dataset\n",
    "        :type: str\n",
    "\n",
    "        :param view_id: GUID of the view\n",
    "        :type: str\n",
    "\n",
    "        :param sleep_sec: seconds to wait between checks\n",
    "        :type: int\n",
    "\n",
    "        :param timeout_sec: optional maximum number of seconds to wait; None (default) waits indefinitely\n",
    "        :type: float\n",
    "\n",
    "        :raises Exception: if the view is not found exactly once, or its status is not SUCCESS/PENDING/RUNNING\n",
    "        :raises TimeoutError: if timeout_sec elapses while the view is still PENDING or RUNNING\n",
    "        \"\"\"\n",
    "        deadline = None if timeout_sec is None else time.monotonic() + timeout_sec\n",
    "        while True:\n",
    "            # list_views follows nextToken, so views beyond the first page of 100 are found too\n",
    "            all_views = self.list_views(dataset_id=dataset_id, max_results=100)\n",
    "            matched_views = [v for v in all_views if v['id'] == view_id]\n",
    "\n",
    "            if len(matched_views) != 1:\n",
    "                size = len(matched_views)\n",
    "                raise Exception(f\"Unexpected error: found {size} views that match the view Id: {view_id}\")\n",
    "\n",
    "            status = matched_views[0]['status']\n",
    "            if status == 'SUCCESS':\n",
    "                print(\"View complete\")\n",
    "                return\n",
    "            if status not in ('PENDING', 'RUNNING'):\n",
    "                raise Exception(f\"Bad view status: {status}, failing now.\")\n",
    "            if deadline is not None and time.monotonic() >= deadline:\n",
    "                raise TimeoutError(f\"View {view_id} still {status} after {timeout_sec} sec\")\n",
    "            print(f\"View status is still {status}, continue to wait till finish...\")\n",
    "            time.sleep(sleep_sec)\n",
    "\n",
    "    def list_changesets(self, dataset_id: str):\n",
    "        \"\"\"\n",
    "        Return every changeset of a dataset, requested sorted by CREATE_TIMESTAMP, following nextToken pagination.\n",
    "\n",
    "        :param dataset_id: GUID of the dataset\n",
    "        :type: str\n",
    "\n",
    "        :return: list of changeset records\n",
    "        \"\"\"\n",
    "        page = self.client.list_changesets(datasetId=dataset_id, sortKey='CREATE_TIMESTAMP')\n",
    "        changesets = page['changesets']\n",
    "        while 'nextToken' in page:\n",
    "            page = self.client.list_changesets(datasetId=dataset_id, sortKey='CREATE_TIMESTAMP',\n",
    "                                               nextToken=page['nextToken'])\n",
    "            changesets.extend(page['changesets'])\n",
    "        return changesets\n",
    "\n",
    "    def list_views(self, dataset_id: str, max_results=50):\n",
    "        \"\"\"\n",
    "        Return every materialized view (snapshot) of a dataset, following nextToken pagination.\n",
    "\n",
    "        :param dataset_id: GUID of the dataset\n",
    "        :type: str\n",
    "\n",
    "        :param max_results: page size requested per call\n",
    "        :type: int\n",
    "\n",
    "        :return: list of view records\n",
    "        \"\"\"\n",
    "        page = self.client.list_materialization_snapshots(datasetId=dataset_id, maxResults=max_results)\n",
    "        snapshots = page['materializationSnapshots']\n",
    "        while 'nextToken' in page:\n",
    "            page = self.client.list_materialization_snapshots(datasetId=dataset_id, maxResults=max_results,\n",
    "                                                              nextToken=page['nextToken'])\n",
    "            snapshots.extend(page['materializationSnapshots'])\n",
    "        return snapshots\n",
    "\n",
    "    def list_datasets(self, max_results: int):\n",
    "        \"\"\"\n",
    "        Return up to max_results datasets from a single list_datasets call (no pagination).\n",
    "\n",
    "        :param max_results: maximum number of datasets to request\n",
    "        :type: int\n",
    "\n",
    "        :return: list of dataset records\n",
    "        \"\"\"\n",
    "        response = self.client.list_datasets(maxResults=max_results)\n",
    "        return self.get_list(response, 'datasets')\n",
    "\n",
    "    def list_dataset_types(self):\n",
    "        \"\"\"\n",
    "        Return every dataset type, requested sorted by NAME, following nextToken pagination.\n",
    "\n",
    "        :return: list of dataset type summaries\n",
    "        \"\"\"\n",
    "        page = self.client.list_dataset_types(sort='NAME')\n",
    "        summaries = page['datasetTypeSummaries']\n",
    "        while 'nextToken' in page:\n",
    "            page = self.client.list_dataset_types(sort='NAME', nextToken=page['nextToken'])\n",
    "            summaries.extend(page['datasetTypeSummaries'])\n",
    "        return summaries\n",
    "\n",
    "    @staticmethod\n",
    "    def get_execution_role():\n",
    "        \"\"\"\n",
    "        Return, as a string, the execution role reported by SageMaker's get_execution_role().\n",
    "\n",
    "        :return: the ARN of the execution role of the SageMaker Studio notebook user\n",
    "        \"\"\"\n",
    "        import sagemaker\n",
    "\n",
    "        role_arn = sagemaker.get_execution_role()\n",
    "        return f\"{role_arn}\"\n",
    "\n",
    "    def get_user_ingestion_info(self):\n",
    "        return (self.client.get_user_ingestion_info())\n",
    "\n",
    "    def upload_pandas(self, data_frame: pd.DataFrame):\n",
    "        import awswrangler as wr\n",
    "        resp = self.client.get_working_location(locationType='INGESTION')\n",
    "        upload_location = resp['s3Uri']\n",
    "        wr.s3.to_parquet(data_frame, f\"{upload_location}data.parquet\", index=False, boto3_session=self._boto3_session)\n",
    "        return upload_location\n",
    "\n",
    "    def ingest_pandas(self, data_frame: pd.DataFrame, dataset_id: str, change_type: str, wait_for_completion=True):\n",
    "        print(\"Uploading the pandas dataframe ...\")\n",
    "        upload_location = self.upload_pandas(data_frame)\n",
    "\n",
    "        print(\"Data upload finished. Ingesting data ...\")\n",
    "        return self.ingest_from_s3(upload_location, dataset_id, change_type, wait_for_completion, format_type='PARQUET')\n",
    "\n",
    "    def read_view_as_pandas(self, dataset_id: str, view_id: str):\n",
    "        \"\"\"\n",
    "        Returns a pandas dataframe the view of the given dataset.  Views in FinSpace can be quite large, be careful!\n",
    "\n",
    "        :param dataset_id:\n",
    "        :param view_id:\n",
    "\n",
    "        :return: Pandas dataframe with all data of the view\n",
    "        \"\"\"\n",
    "        import awswrangler as wr  # use awswrangler to read the table\n",
    "\n",
    "        # @todo: switch to DescribeMateriliazation when available in HFS\n",
    "        views = self.list_views(dataset_id=dataset_id, max_results=50)\n",
    "        filtered = [v for v in views if v['id'] == view_id]\n",
    "\n",
    "        if len(filtered) == 0:\n",
    "            raise Exception('No such view found')\n",
    "        if len(filtered) > 1:\n",
    "            raise Exception('Internal Server error')\n",
    "        view = filtered[0]\n",
    "\n",
    "        # 0. Ensure view is ready to be read\n",
    "        if (view['status'] != 'SUCCESS'):\n",
    "            status = view['status']\n",
    "            print(f'view run status is not ready: {status}. Returning empty.')\n",
    "            return\n",
    "\n",
    "        glue_db_name = view['destinationTypeProperties']['databaseName']\n",
    "        glue_table_name = view['destinationTypeProperties']['tableName']\n",
    "\n",
    "        # determine if the table has partitions first, different way to read is there are partitions\n",
    "        p = wr.catalog.get_partitions(table=glue_table_name, database=glue_db_name, boto3_session=self._boto3_session)\n",
    "\n",
    "        def no_filter(partitions):\n",
    "            if len(partitions.keys()) > 0:\n",
    "                return True\n",
    "\n",
    "            return False\n",
    "\n",
    "        df = None\n",
    "\n",
    "        if len(p) == 0:\n",
    "            df = wr.s3.read_parquet_table(table=glue_table_name, database=glue_db_name,\n",
    "                                          boto3_session=self._boto3_session)\n",
    "        else:\n",
    "            spath = wr.catalog.get_table_location(table=glue_table_name, database=glue_db_name,\n",
    "                                                  boto3_session=self._boto3_session)\n",
    "            cpath = wr.s3.list_directories(f\"{spath}/*\", boto3_session=self._boto3_session)\n",
    "\n",
    "            read_path = f\"{spath}/\"\n",
    "\n",
    "            # just one?  Read it\n",
    "            if len(cpath) == 1:\n",
    "                read_path = cpath[0]\n",
    "\n",
    "            df = wr.s3.read_parquet(read_path, dataset=True, partition_filter=no_filter,\n",
    "                                    boto3_session=self._boto3_session)\n",
    "\n",
    "        # Query Glue table directly with wrangler\n",
    "        return df\n",
    "\n",
    "    @staticmethod\n",
    "    def get_schema_from_pandas(df: pd.DataFrame):\n",
    "        \"\"\"\n",
    "        Returns the FinSpace schema columns from the given pandas dataframe.\n",
    "\n",
    "        :param df: pandas dataframe to interrogate for the schema\n",
    "\n",
    "        :return: FinSpace column schema list\n",
    "        \"\"\"\n",
    "\n",
    "        # for translation to FinSpace's schema\n",
    "        # 'STRING'|'CHAR'|'INTEGER'|'TINYINT'|'SMALLINT'|'BIGINT'|'FLOAT'|'DOUBLE'|'DATE'|'DATETIME'|'BOOLEAN'|'BINARY'\n",
    "        DoubleType = \"DOUBLE\"\n",
    "        FloatType = \"FLOAT\"\n",
    "        DateType = \"DATE\"\n",
    "        StringType = \"STRING\"\n",
    "        IntegerType = \"INTEGER\"\n",
    "        LongType = \"BIGINT\"\n",
    "        BooleanType = \"BOOLEAN\"\n",
    "        TimestampType = \"DATETIME\"\n",
    "\n",
    "        hab_columns = []\n",
    "\n",
    "        for name in dict(df.dtypes):\n",
    "            p_type = df.dtypes[name]\n",
    "\n",
    "            switcher = {\n",
    "                \"float64\": DoubleType,\n",
    "                \"int64\": IntegerType,\n",
    "                \"datetime64[ns, UTC]\": TimestampType,\n",
    "                \"datetime64[ns]\": DateType\n",
    "            }\n",
    "\n",
    "            habType = switcher.get(str(p_type), StringType)\n",
    "\n",
    "            hab_columns.append({\n",
    "                \"dataType\": habType,\n",
    "                \"name\": name,\n",
    "                \"description\": \"\"\n",
    "            })\n",
    "\n",
    "        return (hab_columns)\n",
    "\n",
    "    @staticmethod\n",
    "    def get_date_cols(df: pd.DataFrame):\n",
    "        \"\"\"\n",
    "        Returns which are the data columns found in the pandas dataframe.\n",
    "        Pandas does the hard work to figure out which of the columns can be considered to be date columns.\n",
    "\n",
    "        :param df: pandas dataframe to interrogate for the schema\n",
    "\n",
    "        :return: list of column names that can be parsed as dates by pandas\n",
    "\n",
    "        \"\"\"\n",
    "        date_cols = []\n",
    "\n",
    "        for name in dict(df.dtypes):\n",
    "\n",
    "            p_type = df.dtypes[name]\n",
    "\n",
    "            if str(p_type).startswith(\"date\"):\n",
    "                date_cols.append(name)\n",
    "\n",
    "        return (date_cols)\n",
    "\n",
    "    def get_best_schema_from_csv(self, path, is_s3=True, read_rows=500, sep=','):\n",
    "        \"\"\"\n",
    "        Uses multiple reads of the file with pandas to determine schema of the referenced files.\n",
    "        Files are expected to be csv.\n",
    "\n",
    "        :param path: path to the files to read\n",
    "        :type: str\n",
    "\n",
    "        :param is_s3: True if the path is s3;  False if filesystem\n",
    "        :type: bool\n",
    "\n",
    "        :param read_rows: number of rows to sample for determining schema\n",
    "\n",
    "        :param sep:\n",
    "\n",
    "        :return dict: schema for FinSpace\n",
    "        \"\"\"\n",
    "        #\n",
    "        # best efforts to determine the schema, sight unseen\n",
    "        import awswrangler as wr\n",
    "\n",
    "        # 1: get the base schema\n",
    "        df1 = None\n",
    "\n",
    "        if is_s3:\n",
    "            df1 = wr.s3.read_csv(path, nrows=read_rows, sep=sep)\n",
    "        else:\n",
    "            df1 = pd.read_csv(path, nrows=read_rows, sep=sep)\n",
    "\n",
    "        num_cols = len(df1.columns)\n",
    "\n",
    "        # with number of columns, try to infer dates\n",
    "        df2 = None\n",
    "\n",
    "        if is_s3:\n",
    "            df2 = wr.s3.read_csv(path, parse_dates=list(range(0, num_cols)), infer_datetime_format=True,\n",
    "                                 nrows=read_rows, sep=sep)\n",
    "        else:\n",
    "            df2 = pd.read_csv(path, parse_dates=list(range(0, num_cols)), infer_datetime_format=True, nrows=read_rows,\n",
    "                              sep=sep)\n",
    "\n",
    "        date_cols = self.get_date_cols(df2)\n",
    "\n",
    "        # with dates known, parse the file fully\n",
    "        df = None\n",
    "\n",
    "        if is_s3:\n",
    "            df = wr.s3.read_csv(path, parse_dates=date_cols, infer_datetime_format=True, nrows=read_rows, sep=sep)\n",
    "        else:\n",
    "            df = pd.read_csv(path, parse_dates=date_cols, infer_datetime_format=True, nrows=read_rows, sep=sep)\n",
    "\n",
    "        schema_cols = self.get_schema_from_pandas(df)\n",
    "\n",
    "        return (schema_cols)\n",
    "\n",
    "    def s3_upload_file(self, source_file: str, s3_destination: str):\n",
    "        \"\"\"\n",
    "        Uploads a local file (full path) to the s3 destination given (expected form: s3://<bucket>/<prefix>/).\n",
    "        The filename will have spaces replaced with _.\n",
    "\n",
    "        :param source_file: path of file to upload\n",
    "        :param s3_destination: full path to where to save the file\n",
    "        :type: str\n",
    "\n",
    "        \"\"\"\n",
    "        hab_s3_client = self._boto3_session.client(service_name='s3')\n",
    "\n",
    "        o = urlparse(s3_destination)\n",
    "        bucket = o.netloc\n",
    "        prefix = o.path.lstrip('/')\n",
    "\n",
    "        fname = os.path.basename(source_file)\n",
    "\n",
    "        hab_s3_client.upload_file(source_file, bucket, f\"{prefix}{fname.replace(' ', '_')}\")\n",
    "\n",
    "    def list_objects(self, s3_location: str):\n",
    "        \"\"\"\n",
    "        lists the objects found at the s3_location. Strips out the boto API response header,\n",
    "        just returns the contents of the location. Internally uses the list_objects_v2.\n",
    "\n",
    "        :param s3_location: path, starting with s3:// to get the list of objects from\n",
    "        :type: str\n",
    "\n",
    "        \"\"\"\n",
    "        o = urlparse(s3_location)\n",
    "        bucket = o.netloc\n",
    "        prefix = o.path.lstrip('/')\n",
    "\n",
    "        results = []\n",
    "\n",
    "        hab_s3_client = self._boto3_session.client(service_name='s3')\n",
    "\n",
    "        paginator = hab_s3_client.get_paginator('list_objects_v2')\n",
    "        pages = paginator.paginate(Bucket=bucket, Prefix=prefix)\n",
    "\n",
    "        for page in pages:\n",
    "            if 'Contents' in page:\n",
    "                results.extend(page['Contents'])\n",
    "\n",
    "        return (results)\n",
    "\n",
    "    def list_clusters(self, status: str = None):\n",
    "        \"\"\"\n",
    "        Lists current clusters and their statuses\n",
    "\n",
    "        :param status: status to filter for\n",
    "\n",
    "        :return dict: list of clusters\n",
    "        \"\"\"\n",
    "\n",
    "        resp = self.client.list_clusters()\n",
    "\n",
    "        clusters = []\n",
    "\n",
    "        if 'clusters' not in resp:\n",
    "            return (clusters)\n",
    "\n",
    "        for c in resp['clusters']:\n",
    "            if status is None:\n",
    "                clusters.append(c)\n",
    "            else:\n",
    "                if c['clusterStatus']['state'] in status:\n",
    "                    clusters.append(c)\n",
    "\n",
    "        return (clusters)\n",
    "\n",
    "    def get_cluster(self, cluster_id):\n",
    "        \"\"\"\n",
    "        Resize the given cluster to desired template\n",
    "\n",
    "        :param cluster_id: cluster id\n",
    "        \"\"\"\n",
    "\n",
    "        clusters = self.list_clusters()\n",
    "\n",
    "        for c in clusters:\n",
    "            if c['clusterId'] == cluster_id:\n",
    "                return (c)\n",
    "\n",
    "        return (None)\n",
    "\n",
    "    def update_cluster(self, cluster_id: str, template: str):\n",
    "        \"\"\"\n",
    "        Resize the given cluster to desired template\n",
    "\n",
    "        :param cluster_id: cluster id\n",
    "        :param template: target template to resize to\n",
    "        \"\"\"\n",
    "\n",
    "        cluster = self.get_cluster(cluster_id=cluster_id)\n",
    "\n",
    "        if cluster['currentTemplate'] == template:\n",
    "            print(f\"Already using template: {template}\")\n",
    "            return (cluster)\n",
    "\n",
    "        self.client.update_cluster(clusterId=cluster_id, template=template)\n",
    "\n",
    "        return (self.get_cluster(cluster_id=cluster_id))\n",
    "\n",
    "    def wait_for_status(self, clusterId: str, status: str, sleep_sec=10, max_wait_sec=900):\n",
    "        \"\"\"\n",
    "        Function polls service until cluster is in desired status.\n",
    "\n",
    "        :param clusterId: the cluster's ID\n",
    "        :param status: desired status for clsuter to reach\n",
    "        :\n",
    "        \"\"\"\n",
    "        total_wait = 0\n",
    "\n",
    "        while True and total_wait < max_wait_sec:\n",
    "            resp = self.client.list_clusters()\n",
    "\n",
    "            this_cluster = None\n",
    "\n",
    "            # is this the cluster?\n",
    "            for c in resp['clusters']:\n",
    "                if clusterId == c['clusterId']:\n",
    "                    this_cluster = c\n",
    "\n",
    "            if this_cluster is None:\n",
    "                print(f\"clusterId:{clusterId} not found\")\n",
    "                return (None)\n",
    "\n",
    "            this_status = this_cluster['clusterStatus']['state']\n",
    "\n",
    "            if this_status.upper() != status.upper():\n",
    "                print(f\"Cluster status is {this_status}, waiting {sleep_sec} sec ...\")\n",
    "                time.sleep(sleep_sec)\n",
    "                total_wait = total_wait + sleep_sec\n",
    "                continue\n",
    "            else:\n",
    "                return (this_cluster)\n",
    "\n",
    "    def get_working_location(self, locationType='SAGEMAKER'):\n",
    "        resp = None\n",
    "        location = self.client.get_working_location(locationType=locationType)\n",
    "\n",
    "        if 's3Uri' in location:\n",
    "            resp = location['s3Uri']\n",
    "\n",
    "        return (resp)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# %load ../Utilities/finspace_spark.py\n",
    "import datetime\n",
    "import time\n",
    "import boto3\n",
    "from botocore.config import Config\n",
    "\n",
    "# FinSpace class with Spark bindings\n",
    "\n",
    "class SparkFinSpace(FinSpace):\n",
    "    import pyspark\n",
    "    def __init__(\n",
    "        self, \n",
    "        spark: pyspark.sql.session.SparkSession = None,\n",
    "        config = Config(retries = {'max_attempts': 0, 'mode': 'standard'}),\n",
    "        dev_overrides: dict = None\n",
    "    ):\n",
    "        FinSpace.__init__(self, config=config, dev_overrides=dev_overrides)\n",
    "        self.spark = spark # used on Spark cluster for reading views, creating changesets from DataFrames\n",
    "        \n",
    "    def upload_dataframe(self, data_frame: pyspark.sql.dataframe.DataFrame):\n",
    "        resp = self.client.get_user_ingestion_info()\n",
    "        upload_location = resp['ingestionPath']\n",
    "#        data_frame.write.option('header', 'true').csv(upload_location)\n",
    "        data_frame.write.parquet(upload_location)\n",
    "        return upload_location\n",
    "    \n",
    "    def ingest_dataframe(self, data_frame: pyspark.sql.dataframe.DataFrame, dataset_id: str, change_type: str, wait_for_completion=True):\n",
    "        print(\"Uploading data...\")\n",
    "        upload_location = self.upload_dataframe(data_frame)\n",
    "        \n",
    "        print(\"Data upload finished. Ingesting data...\")\n",
    "        \n",
    "        return self.ingest_from_s3(upload_location, dataset_id, change_type, wait_for_completion, format_type='parquet', format_params={})\n",
    "    \n",
    "    def read_view_as_spark(\n",
    "        self,\n",
    "        dataset_id: str,\n",
    "        view_id: str\n",
    "        ):\n",
    "        # TODO: switch to DescribeMatz when available in HFS\n",
    "        views = self.list_views(dataset_id=dataset_id, max_results=50)\n",
    "        filtered = [v for v in views if v['id'] == view_id]\n",
    "\n",
    "        if len(filtered) == 0:\n",
    "            raise Exception('No such view found')\n",
    "        if len(filtered) > 1:\n",
    "            raise Exception('Internal Server error')\n",
    "        view = filtered[0]\n",
    "        \n",
    "        # 0. Ensure view is ready to be read\n",
    "        if (view['status'] != 'SUCCESS'): \n",
    "            status = view['status'] \n",
    "            print(f'view run status is not ready: {status}. Returning empty.')\n",
    "            return\n",
    "\n",
    "        glue_db_name = view['destinationTypeProperties']['databaseName']\n",
    "        glue_table_name = view['destinationTypeProperties']['tableName']\n",
    "        \n",
    "        # Query Glue table directly with catalog function of spark\n",
    "        return self.spark.table(f\"`{glue_db_name}`.`{glue_table_name}`\")\n",
    "    \n",
    "    def get_schema_from_spark(self, data_frame: pyspark.sql.dataframe.DataFrame):\n",
    "        from pyspark.sql.types import StructType\n",
    "\n",
    "        # for translation to FinSpace's schema\n",
    "        # 'STRING'|'CHAR'|'INTEGER'|'TINYINT'|'SMALLINT'|'BIGINT'|'FLOAT'|'DOUBLE'|'DATE'|'DATETIME'|'BOOLEAN'|'BINARY'\n",
    "        DoubleType    = \"DOUBLE\"\n",
    "        FloatType     = \"FLOAT\"\n",
    "        DateType      = \"DATE\"\n",
    "        StringType    = \"STRING\"\n",
    "        IntegerType   = \"INTEGER\"\n",
    "        LongType      = \"BIGINT\"\n",
    "        BooleanType   = \"BOOLEAN\"\n",
    "        TimestampType = \"DATETIME\"\n",
    "        \n",
    "        hab_columns = []\n",
    "\n",
    "        items = [i for i in data_frame.schema] \n",
    "\n",
    "        switcher = {\n",
    "            \"BinaryType\"    : StringType,\n",
    "            \"BooleanType\"   : BooleanType,\n",
    "            \"ByteType\"      : IntegerType,\n",
    "            \"DateType\"      : DateType,\n",
    "            \"DoubleType\"    : FloatType,\n",
    "            \"IntegerType\"   : IntegerType,\n",
    "            \"LongType\"      : IntegerType,\n",
    "            \"NullType\"      : StringType,\n",
    "            \"ShortType\"     : IntegerType,\n",
    "            \"StringType\"    : StringType,\n",
    "            \"TimestampType\" : TimestampType,\n",
    "        }\n",
    "\n",
    "        \n",
    "        for i in items:\n",
    "#            print( f\"name: {i.name} type: {i.dataType}\" )\n",
    "\n",
    "            habType = switcher.get( str(i.dataType), StringType)\n",
    "\n",
    "            hab_columns.append({\n",
    "                \"dataType\"    : habType, \n",
    "                \"name\"        : i.name,\n",
    "                \"description\" : \"\"\n",
    "            })\n",
    "\n",
    "        return( hab_columns )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# FinSpace Environment\n",
    "\n",
    "Please provide values from your AWS account (S3 location) and your FinSpace environment. The group ID is from the user group you want to associate the dataset to, this example will grant all permissions to the group for this dataset it creates.\n",
    "\n",
    "## Getting the Group ID\n",
    "\n",
    "Navigate to the Analyst group (gear menu, users and groups, select group named Analyst). The URL is of this pattern:  \n",
    "http://**ENVIRONMEN_ID**.**REGION**.amazonfinspace.com/userGroup/**GROUP_ID**  \n",
    "\n",
    "Copy the string for GroupID into the **group_id** variable assignment below\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "service_name: finspace-data\n",
      "endpoint: None\n",
      "region_name: us-east-1"
     ]
    }
   ],
   "source": [
    "\n",
    "# S3 bucket where you exported ESG data from Data Exchance \n",
    "root_folder = 's3://myesgfulldata'\n",
    "\n",
    "# dataset_id, if None will create, if not None, this update will be an append usig bucket contents\n",
    "dataset_id = None\n",
    "\n",
    "# User Group to grant access to the dataset\n",
    "group_id = 'mrwdlitspUyKQJkd93XpgA'\n",
    "\n",
    "\n",
    "# initialize the FinSpace helper object\n",
    "finspace = FinSpace()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset Definitions\n",
    "Capture the dataset's name, description, schema, attribute set, attribute set values, permissions to assign to the permission group."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Name for the dataset\n",
    "name = \"ESG News DATE\"\n",
    "\n",
    "# description for the dataset\n",
    "description = \"\"\"This trial dataset is industrial-scale NLP applied to thousands of news sources to develop in-depth, real-time scoring at the company level on ESG issues\n",
    "\"\"\"\n",
    "\n",
    "# this is the attribute set to use, will search for it in system, this name assumes the Capital Markets Sample Data Bundle was installed\n",
    "att_name = \"Sample Data Attribute Set\"\n",
    "\n",
    "# Attributes to associate, based on the definition of the attribute set\n",
    "att_values = [\n",
    "    { 'field' : 'AssetClass', 'type' : 'TAXONOMY', 'values' : [ 'Equity', 'CommonStocks'] },\n",
    "    { 'field' : 'DataType',   'type' : 'TAXONOMY', 'values' : [ 'News', 'ESG' ] },\n",
    "    { 'field' : 'Source',     'type' : 'TAXONOMY', 'values' : [ 'Amenity Analytics'] },\n",
    "    { 'field' : 'EventType',  'type' : 'TAXONOMY', 'values' : [ 'ClosingPrice' ] },\n",
    "    { 'field' : 'SampleData', 'type' : 'TAXONOMY', 'values' : [ ] }\n",
    "]\n",
    "\n",
    "# Permissions to grant the above group for the created dataset\n",
    "basicPermissions = [\n",
    "    \"ViewDatasetDetails\",\n",
    "    \"ReadDatasetData\",\n",
    "    \"AddDatasetData\",\n",
    "    \"CreateSnapshot\",\n",
    "    \"EditDatasetMetadata\",\n",
    "    \"ManageDatasetPermissions\",\n",
    "    \"DeleteDataset\"\n",
    "]\n",
    "\n",
    "# All datasets have ownership\n",
    "basicOwnerInfo = {\n",
    "    \"phoneNumber\" : \"12125551000\",\n",
    "    \"email\"       : \"jdoe@amazon.com\",\n",
    "    \"name\"        : \"Jane Doe\"\n",
    "}\n",
    "\n",
    "# schema of the dataset\n",
    "schema = {\n",
    "    'primaryKeyColumns': [],\n",
    "    'columns' : [\n",
    "\n",
    "        {'dataType': 'DATE',  'name': 'date', 'description': 'A date for the aggregate ESG scores of a company'},\n",
    "        {'dataType': 'STRING',  'name': 'symbologyId', 'description': 'FactSet unique identifier for a company. Helps in tracking a company in instances where the company ticker has changed'},\n",
    "        {'dataType': 'STRING',  'name': 'companyName', 'description': 'The company name'},\n",
    "        {'dataType': 'STRING',  'name': 'ticker', 'description': 'The company ticker symbol'},\n",
    "        {'dataType': 'STRING',  'name': 'region', 'description': 'The region of the company’s ticker'},\n",
    "        {'dataType': 'DOUBLE',  'name': 'totalPositiveCount', 'description': 'Total count of positive extractions for a date'},\n",
    "        {'dataType': 'DOUBLE',  'name': 'totalNegativeCount', 'description': 'Total count of negative extractions for a date'},\n",
    "        {'dataType': 'DOUBLE',  'name': 'totalCountDailyScore', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'totalWeightedPositiveCount', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'totalWeightedNegativeCount', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'totalWeightedCountDailyScore', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdGovernancePositiveCount', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdGovernanceNegativeCount', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdGovernanceCountDailyScore', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdGovernanceWeightedPositiveCount', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdGovernanceWeightedNegativeCount', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdGovernanceWeightedCountDailyScore', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdSocialPositiveCount', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdSocialNegativeCount', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdSocialCountDailyScore', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdSocialWeightedPositiveCount', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdSocialWeightedNegativeCount', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdSocialWeightedCountDailyScore', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdGeneralPositiveCount', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdGeneralNegativeCount', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdGeneralCountDailyScore', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdGeneralWeightedPositiveCount', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdGeneralWeightedNegativeCount', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdGeneralWeightedCountDailyScore', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdEnvironmentalPositiveCount', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdEnvironmentalNegativeCount', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdEnvironmentalCountDailyScore', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdEnvironmentalWeightedPositiveCount', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdEnvironmentalWeightedNegativeCount', 'description': ''},\n",
    "        {'dataType': 'DOUBLE',  'name': 'kdEnvironmentalWeightedCountDailyScore', 'description': ''}\n",
    "        \n",
    "        \n",
    "    ]\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset ID: nq4xqh0"
     ]
    }
   ],
   "source": [
    "# call FinSpace to create the dataset if no ID was assigned\n",
    "# if an ID was assigned, will not create a dataset but will simply add data to it\n",
    "if dataset_id is None:\n",
    "    dataset_id = finspace.create_dataset(\n",
    "        name = name,\n",
    "        description = description,\n",
    "        permission_group_id = group_id,\n",
    "        dataset_permissions = basicPermissions,\n",
    "        kind = \"TABULAR\",\n",
    "        owner_info = basicOwnerInfo,\n",
    "        schema = schema\n",
    "    )\n",
    "    time.sleep(5)\n",
    "\n",
    "print(f'Dataset ID: {dataset_id}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ingesting from: s3://myesgfulldata\n",
      "Changeset status is still PENDING, waiting 10 sec ...\n",
      "Changeset status is still PENDING, waiting 10 sec ...\n",
      "Changeset complete"
     ]
    }
   ],
   "source": [
    "# use pandas to generate a range of dates between start and end\n",
    "\n",
    "s3_source = f'{root_folder}'\n",
    "print(f'Ingesting from: {s3_source}')\n",
    "\n",
    "try:\n",
    "    changeset_id = finspace.ingest_from_s3(s3_location=s3_source,\n",
    "                                       dataset_id=dataset_id,\n",
    "                                       change_type='APPEND',\n",
    "                                       wait_for_completion=True,\n",
    "                                       format_type='CSV')\n",
    "except Exception as e:\n",
    "    print(e)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Changeset ID: kr2cz8lonXLXncyh8NUrbw"
     ]
    }
   ],
   "source": [
    "print(f'Changeset ID: {changeset_id}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Created autoupdate_snapshot_id = 6L2czXfhsRggiFFlX1132A\n",
      "dataset_id = nq4xqh0"
     ]
    }
   ],
   "source": [
    "# Create an auto-updating View if one does not exist\n",
    "existing_snapshots = finspace.list_views(dataset_id = dataset_id, max_results=100)\n",
    "\n",
    "autoupdate_snapshot_id = None\n",
    "\n",
    "# does one exist?\n",
    "for ss in existing_snapshots:\n",
    "    if ss['autoUpdate'] == True:\n",
    "        autoupdate_snapshot_id = ss['id']\n",
    "\n",
    "# if no auto-updating view, create it\n",
    "if autoupdate_snapshot_id is None:\n",
    "    autoupdate_snapshot_id = finspace.create_auto_update_view(\n",
    "        dataset_id = dataset_id,\n",
    "        destination_type = \"GLUE_TABLE\",\n",
    "        partition_columns = [],\n",
    "        sort_columns = [],\n",
    "        wait_for_completion = False)\n",
    "    print( f\"Created autoupdate_snapshot_id = {autoupdate_snapshot_id}\" )\n",
    "else:\n",
    "    print( f\"Exists: autoupdate_snapshot_id = {autoupdate_snapshot_id}\" )\n",
    "\n",
    "print( f\"dataset_id = {dataset_id}\" )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Associating values to attribute set: Sample Data Attribute Set\n",
      "Nothing to disassociate\n",
      "{'ResponseMetadata': {'RequestId': '35ca75a2-b13c-45c5-ae27-1fdecf80c921', 'HTTPStatusCode': 200, 'HTTPHeaders': {'content-type': 'application/x-amz-json-1.1', 'content-length': '0', 'connection': 'keep-alive', 'date': 'Thu, 12 Aug 2021 09:09:51 GMT', 'x-amzn-requestid': '35ca75a2-b13c-45c5-ae27-1fdecf80c921', 'x-amz-apigw-id': 'D8jG8GlbIAMFwUA=', 'x-amzn-trace-id': 'Root=1-6114e55f-5f82ab9e7f13b442701933f1', 'x-cache': 'Miss from cloudfront', 'via': '1.1 5a8b742274bb7bf8d0871df4a4c7081f.cloudfront.net (CloudFront)', 'x-amz-cf-pop': 'IAD66-C2', 'x-amz-cf-id': 'j3quw8eOLnRLu-pqn_dOBAB8e-H98nOyUCp4qkak6ZvXExykfnFf1g=='}, 'RetryAttempts': 0}}"
     ]
    }
   ],
   "source": [
    "# Associate an attribute set and fill its values\n",
    "# if values where previously populated for this attribute set, this will overwrite them\n",
    "\n",
    "if (att_name is not None and att_values is not None):\n",
    "    print(f\"Associating values to attribute set: {att_name}\")\n",
    "    finspace.associate_attribute_set(att_name=att_name, att_values=att_values, dataset_id=dataset_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Last Run: 2021-08-12 09:09:52.863463"
     ]
    }
   ],
   "source": [
    "import datetime\n",
    "print( f\"Last Run: {datetime.datetime.now()}\" )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "FinSpace PySpark (finspace-sparkmagic-ffd02/latest)",
   "language": "python",
   "name": "pysparkkernel__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:998492697549:image/finspace-sparkmagic-ffd02"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "python",
    "version": 3
   },
   "mimetype": "text/x-python",
   "name": "pyspark",
   "pygments_lexer": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
