import sys
from datetime import datetime

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType

# Position of the wind speed rate inside the comma-separated WND field, and the
# sentinel the data uses for a missing speed (both taken from the filter logic).
# NOTE(review): assumes WND follows the NOAA ISD layout
# "angle,angle_quality,type_code,speed_rate,speed_quality" — confirm against the dataset docs.
_WND_SPEED_INDEX = 3
_WND_MISSING_SPEED = "9999"
# Raw speed values are divided by this to get the reported wind speed.
_WND_SPEED_SCALE = 10
# Number of rows shown when previewing each intermediate DataFrame.
_PREVIEW_ROWS = 10


def _parse_args(argv):
    """Return ``(input_path, output_path)`` from *argv*, or exit with usage.

    Exits with status 1 (not 0) on a bad invocation so that callers such as
    schedulers or shell scripts can detect the failure.
    """
    if len(argv) != 3:
        print("Usage: spark-etl [input-folder] [output-folder]", file=sys.stderr)
        sys.exit(1)
    return argv[1], argv[2]


def _preview(frame):
    """Print the schema and the first few rows of *frame* for diagnostics."""
    frame.printSchema()
    # DataFrame.show() prints on its own and returns None; don't wrap it in print().
    frame.show(_PREVIEW_ROWS)


def _wind_speed_raw():
    """Return a Column holding the raw (string) speed field split out of WND."""
    return F.split(F.col("WND"), ",")[_WND_SPEED_INDEX]


def _select_valid_wind(df):
    """Keep DATE/NAME/WND, rename NAME to location_name, and drop missing speeds."""
    return (
        df.select("DATE", "NAME", "WND")
        .withColumnRenamed("NAME", "location_name")
        .filter(_wind_speed_raw() != _WND_MISSING_SPEED)
    )


def _parse_wind(subset_df):
    """Derive ``wind_speed`` (raw value / 10) and ``measurement_year`` from DATE."""
    return (
        subset_df
        .withColumn(
            "wind_speed",
            _wind_speed_raw().cast(DoubleType()) / _WND_SPEED_SCALE,
        )
        .withColumn("measurement_year", F.year(F.col("DATE")))
        .select("location_name", "measurement_year", "wind_speed")
    )


def _aggregate_yearly(wind_date_df):
    """Compute min, average and max wind speed per location and year."""
    return (
        wind_date_df
        .groupBy("location_name", "measurement_year")
        .agg(
            F.min("wind_speed").alias("min_wind_speed"),
            F.avg("wind_speed").alias("avg_wind_speed"),
            F.max("wind_speed").alias("max_wind_speed"),
        )
    )


def main(argv=None):
    """Run the WindyCity ETL job.

    Reads the CSV data at ``argv[1]`` (with a header row and inferred schema),
    computes yearly min/avg/max wind speed per location, and writes the result
    as CSV under ``<argv[2]>/finhack-output/<YYYY-MM-DD-HH-MM>/``.

    Args:
        argv: Command-line vector; defaults to ``sys.argv``.

    Raises:
        SystemExit: with status 1 when the argument count is wrong.
    """
    noaa_isd, output_s3 = _parse_args(sys.argv if argv is None else argv)

    spark = SparkSession.builder.appName("WindyCity").getOrCreate()
    try:
        # Load the data from the input path (e.g. an S3 prefix).
        df = (
            spark.read
            .format("csv")
            .option("header", "true")
            .option("inferSchema", "true")
            .load(noaa_isd)
        )
        _preview(df)

        # Select the specific columns and exclude missing wind speeds.
        subset_df = _select_valid_wind(df)
        _preview(subset_df)

        # Parse the year and the wind speed (scaled back from the raw data).
        wind_date_df = _parse_wind(subset_df)
        _preview(wind_date_df)

        # Find the yearly min, avg and max wind speed for each location.
        agg_wind_df = _aggregate_yearly(wind_date_df)
        _preview(agg_wind_df)

        # Write the output under a timestamped folder of the output location.
        current_time = datetime.now().strftime('%Y-%m-%d-%H-%M')
        agg_wind_df.write.csv(f"{output_s3}/finhack-output/{current_time}/")
    finally:
        # Release the Spark session even when a stage above fails.
        spark.stop()


if __name__ == "__main__":
    main()
