Write code to remove duplicate records based on a specific column in PySpark, keeping only the rows whose value in that column is unique (every row whose value appears more than once is dropped).

Example: 

Original:
+----------+------------+----------+------+
|EmployeeID|EmployeeName|Department|Salary|
+----------+------------+----------+------+
|         1|    John Doe|   Finance| 55000|
|         2|  Jane Smith|        IT| 75000|
|         3|   Sam Brown|        HR| 55000|
|         4| Emily Davis|        IT| 80000|
+----------+------------+----------+------+

Result:
+----------+------------+----------+------+
|EmployeeID|EmployeeName|Department|Salary|
+----------+------------+----------+------+
|         2|  Jane Smith|        IT| 75000|
|         4| Emily Davis|        IT| 80000|
+----------+------------+----------+------+

Python Code:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# Initialize a Spark session
spark = SparkSession.builder.appName("EmployeeDF").getOrCreate()

# Sample employee data with duplicate Salary
employee_data = [
    (1, "John Doe", "Finance", 55000),
    (2, "Jane Smith", "IT", 75000),
    (3, "Sam Brown", "HR", 55000),
    (4, "Emily Davis", "IT", 80000)
]

# Define the schema and create the DataFrame
schema = ["EmployeeID", "EmployeeName", "Department", "Salary"]
employee_df = spark.createDataFrame(employee_data, schema)

# Show the original DataFrame
print("Original DataFrame:")
employee_df.show()

# Step 1: Group by salary and count occurrences
salary_counts = employee_df.groupBy("Salary").count()

# Step 2: Filter out salaries that have more than 1 occurrence (i.e., duplicates)
unique_salaries = salary_counts.filter(F.col("count") == 1).select("Salary")

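# Step 3: Join back with the original DataFrame to get rows with unique salaries.
# A join on "Salary" moves that key column to the front, so re-select the
# original column order; orderBy makes the displayed row order deterministic.
result_df = (
    employee_df.join(unique_salaries, on="Salary", how="inner")
    .select(employee_df.columns)
    .orderBy("EmployeeID")
)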

# Show the resulting DataFrame with only unique salary records
print("DataFrame after removing duplicate salary records:")
result_df.show()


Explanation:

  1. Group by Salary and count: We first group the data by the Salary column and count how many times each salary appears.
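  2. Filter: We keep only the salaries whose count is exactly 1. A salary shared by two or more employees (here 55000, held by John Doe and Sam Brown) is excluded entirely rather than reduced to a single row.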
  3. Join: Finally, we join back to the original DataFrame to keep only the rows with unique salaries.