XGBoost for Night Sky Classification

Debarshi Raj Basumatary
8 min read · Jul 24, 2023

This article will cover the basics of Decision Trees and XGBoost, and demonstrate how to implement the latter for classifying celestial objects in the night sky as either a Galaxy, Star, or Quasar.


Brief Intro to ML Algorithms

A decision tree is a widely used algorithm in machine learning. It is represented by a flowchart-like tree structure where an internal node represents a feature (or attribute), the branch represents a decision rule, and each leaf node represents an outcome. Decision trees are commonly used in classification and regression tasks and can efficiently handle large datasets.

However, on their own, they are not very accurate and are prone to overfitting.

One clever way to improve predictions is to train many decision trees, each on a random subset of the data and features, aggregate their outputs, and predict the final outcome from the average (for regression) or a majority vote (for classification). This is the Random Forest algorithm. However, a forest of hundreds of trees can become computationally heavy, and it may still struggle when the dataset is imbalanced.
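To make this concrete, here is a minimal sketch on synthetic data (not the stellar dataset used later); the exact scores will vary, but the forest usually beats the single tree:

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic toy problem, purely for illustration
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

tree = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
forest = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_train, y_train)

print("single tree:", tree.score(X_test, y_test))
print("random forest:", forest.score(X_test, y_test))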

Another approach is to use Gradient Boosted Trees.

The Gradient Boosted Tree (GBT) algorithm begins with an initial prediction and iteratively adds decision trees to correct its errors. The process starts with an initial prediction from a base estimator (an initial decision tree). The residuals (errors) of that prediction are then used as the training targets for the next estimator (another decision tree): each new tree trains on the same features, but its labels are the errors left over from the previous trees. This process continues until the desired level of accuracy is achieved.
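To make the loop concrete, here is a toy regression sketch of that idea, a minimal illustration rather than a real GBT implementation (no loss-specific gradients, just plain residual fitting):

import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(42)
X = rng.uniform(-3, 3, size=(200, 1))
y = np.sin(X).ravel() + rng.normal(0, 0.1, 200)

learning_rate = 0.1
prediction = np.full_like(y, y.mean())  # initial prediction: the mean
trees = []
for _ in range(100):
    residual = y - prediction                      # errors of the ensemble so far
    tree = DecisionTreeRegressor(max_depth=3).fit(X, residual)
    prediction += learning_rate * tree.predict(X)  # nudge predictions toward the target
    trees.append(tree)

print("final training MSE:", np.mean((y - prediction) ** 2))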

It excels in tasks with high-dimensional data and delivers reliable results due to its ability to handle different loss functions and provide feature importance estimates.

XGBoost represents a next-generation evolution of Gradient Boosted Trees, building upon the same foundational principles while introducing several essential enhancements. Notably, it incorporates parallelization, regularization techniques, and a more efficient greedy algorithm for identifying optimal split points during tree construction. These improvements result in significantly faster training times and enhanced model generalization capabilities. Its exceptional accuracy and high-performance capabilities make it a go-to choice for many data scientists and researchers.
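In the scikit-learn-style API these enhancements surface directly as constructor parameters; the values below are illustrative, not tuned choices:

from xgboost import XGBClassifier

model = XGBClassifier(
    n_jobs=4,            # parallel tree construction across CPU threads
    reg_lambda=1.0,      # L2 regularization on leaf weights
    reg_alpha=0.0,       # L1 regularization on leaf weights
    tree_method="hist",  # histogram-based approximate split finding
)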

It was developed by Tianqi Chen, a data scientist and software engineer, during his Ph.D. at the University of Washington.

If you are interested in a deeper understanding of the math and theory behind XGBoost, the original paper, “XGBoost: A Scalable Tree Boosting System” by Chen and Guestrin, is a good place to start.

The Dataset

The “Stellar Classification Dataset — SDSS17” by FEDESORIANO on Kaggle is derived from the Sloan Digital Sky Survey and contains spectroscopic observations of celestial objects in the night sky. The dataset consists of 100,000 observations, each described by 17 feature columns and 1 class column that identifies it as either a star, galaxy, or quasar.

Quasars, short for “quasi-stellar radio sources,” are extremely bright and distant astronomical objects found at the centres of galaxies. Quasars are not stars but rather the active cores of distant galaxies, powered by supermassive black holes.

  1. obj_ID = Object Identifier, the unique value that identifies the object in the image catalog used by the CAS
  2. alpha = Right Ascension angle (at J2000 epoch)
  3. delta = Declination angle (at J2000 epoch)
  4. u = Ultraviolet filter in the photometric system
  5. g = Green filter in the photometric system
  6. r = Red filter in the photometric system
  7. i = Near Infrared filter in the photometric system
  8. z = Infrared filter in the photometric system
  9. run_ID = Run Number used to identify the specific scan
  10. rerun_ID = Rerun Number to specify how the image was processed
  11. cam_col = Camera column to identify the scanline within the run
  12. field_ID = Field number to identify each field
  13. spec_obj_ID = Unique ID used for optical spectroscopic objects (this means that 2 different observations with the same spec_obj_ID must share the output class)
  14. class = object class (galaxy, star or quasar object)
  15. redshift = redshift value based on the increase in wavelength
  16. plate = plate ID, identifies each plate in SDSS
  17. MJD = Modified Julian Date, used to indicate when a given piece of SDSS data was taken
  18. fiber_ID = fiber ID that identifies the fiber that pointed the light at the focal plane in each observation

Implementation

Requirements:
1. Knowledge of Python, Pandas, Numpy, Matplotlib and Seaborn.
2. Anaconda

Step 1: Install XGBoost

conda install -c conda-forge py-xgboost
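If you are not using Anaconda, XGBoost can also be installed with pip:

pip install xgboost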

Step 2: Import all the required libraries

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, accuracy_score, confusion_matrix
from xgboost import XGBClassifier

XGBClassifier is a scikit-learn-compatible wrapper over the XGBoost library; it allows us to use XGBoost directly with the scikit-learn and pandas APIs.

Step 3: Import & Validate the Dataset

stellar_data = pd.read_csv('star_classification.csv')
stellar_data.head()

For training the model we will not need “obj_ID”, since it is just a catalog identifier.
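If you want to drop it up front, a one-liner does the job (note that the accuracy figures later in this article were obtained with the full feature set, so expect slightly different numbers if you drop columns):

stellar_data = stellar_data.drop(columns=['obj_ID'])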

stellar_data.info()

There are no null values, and all the columns are either int or float; there are no categorical columns except for the output class column.

stellar_data['class'].value_counts()

There are only three classes: Galaxy, Star, and Quasar.
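To see how balanced the classes are, you can also print the proportions; this shows whether one class dominates, which matters for the metric discussion in Step 8:

stellar_data['class'].value_counts(normalize=True)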

Step 4: Visualise the Data

The code below loops over all the columns, except for the ID, and plots a histogram for each. I have only shown two; try the rest yourself.

columns = stellar_data.columns
sns.set_theme(style="darkgrid")
for col in columns:
    if col != 'obj_ID':
        sns.displot(stellar_data, x=col)

You can analyse the plots for exploratory analysis and derive valuable insights, but I am not an astronomer or astrophysicist, so I leave the rest to you.

Step 5: Prepare the data

The code below moves the class column to the end of the dataframe.

label_col = stellar_data['class']
del stellar_data['class']
stellar_data['class'] = label_col
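An equivalent one-liner uses pandas’ pop(), which removes and returns the column in a single step:

label_col = stellar_data.pop('class')
stellar_data['class'] = label_col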

Now we separate the data into X and y, and then into train and test sets.

X = stellar_data.iloc[:,:-1]
y = stellar_data.iloc[:,-1]
X_train, X_test, y_train, y_test = train_test_split(X, y,test_size=0.2, random_state=42)
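One version-dependent caveat: recent XGBoost releases (1.6 and later, as far as I know) no longer accept string class labels, so fit() may raise an error on “GALAXY”/“STAR”/“QSO”. If you hit that error, a minimal fix is to encode the labels with scikit-learn’s LabelEncoder:

from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
y_train = le.fit_transform(y_train)  # e.g. GALAXY -> 0, QSO -> 1, STAR -> 2
y_test = le.transform(y_test)
# Use le.classes_ for readable tick labels later, and
# le.inverse_transform(y_pred) to map predictions back to class names.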

Step 6: Training the Model

model = XGBClassifier(n_jobs=4,n_estimators=100,learning_rate=1)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
accuracy_score(y_test, y_pred)

Ignore the hyperparameters for now.

The model has an accuracy of 97.86%, not bad. But we can go further.

Step 7: Tuning Hyperparameters

XGBoost provides a long list of hyperparameters, but for our implementation we will focus on three.

  1. learning_rate (eta): Controls the step size at each iteration of gradient boosting. Lower values make the model more robust but require more boosting rounds. Typical range: 0 to 1.
  2. n_estimators: The number of boosting rounds (decision trees) to build. More trees can improve performance but increase computation time. Typical range: 50 to 5000.
  3. max_depth: Maximum depth of each decision tree. Controls the complexity of the model and helps prevent overfitting. Typical range: 3 to 10.

The higher n_estimators is, the lower learning_rate should be. A high max_depth captures more insights and feature interactions but also increases the risk of overfitting.

Let’s now fit the model with the new hyperparameters.

model = XGBClassifier(n_jobs=4,n_estimators=300,learning_rate=0.1,max_depth=5)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
accuracy_score(y_test, y_pred)

n_jobs sets the number of threads spawned to train the model; it is part of XGBoost’s parallelisation optimisation.

The accuracy has improved to 97.94% from 97.86%. You can experiment with other parameters to further improve the result; I leave that experimentation to you. Instead of tinkering manually, you can use scikit-learn’s GridSearchCV and RandomizedSearchCV to find the best parameters, but they can consume a large amount of compute.

# from sklearn.model_selection import GridSearchCV

# parameters = {
#     'learning_rate': [0.1, 0.01, 0.001],  # Step size shrinkage used in update to prevent overfitting
#     'max_depth': [3, 5, 7],               # Maximum depth of a tree
#     'n_estimators': [100, 200, 300],      # Number of boosting rounds
#     'subsample': [0.8, 0.9],              # Subsample ratio of the training instances
#     'colsample_bytree': [0.8, 0.9],       # Subsample ratio of columns when constructing each tree
#     'gamma': [0, 0.1, 0.2],               # Minimum loss reduction required to make a further partition on a leaf node
#     'min_child_weight': [1, 3, 5]         # Minimum sum of instance weight (hessian) needed in a child
# }

# # Create GridSearchCV with the XGBoost classifier and parameters
# grid_search = GridSearchCV(estimator=XGBClassifier(), param_grid=parameters, scoring='accuracy', cv=5, n_jobs=5, verbose=2)

# # Perform the grid search on the training data
# grid_search.fit(X_train, y_train)

# # Access the best parameters found by GridSearchCV
# best_params = grid_search.best_params_

# print("Best Parameters:", best_params)

Step 8: Metrics of our model

For classification, the accuracy score alone is not enough, especially with an imbalanced dataset (where instances of one class outnumber the others by a huge margin). We should also check precision (the proportion of true positive predictions among all positive predictions) and recall (the proportion of true positive predictions among all actual positive instances).
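In terms of true positives (TP), false positives (FP), and false negatives (FN):

Precision = TP / (TP + FP)
Recall = TP / (TP + FN)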

In fraud detection, precision is more important because you want to minimize false positives. False positives in this context would mean flagging a transaction as fraudulent when it’s actually legitimate, which could inconvenience customers and lead to a loss of trust.

On the other hand, in medicine, especially in disease detection or diagnosis, recall is more important. It’s crucial to capture all positive cases (e.g., detecting a disease) even if it means having a higher number of false positives. Missing a true positive (false negative) in medicine could have serious consequences for the patient’s health and well-being.

Let’s build a confusion matrix to check the true positives and true negatives of the model’s predictions.

cm = confusion_matrix(y_test, y_pred, labels=model.classes_)
fig, ax = plt.subplots(figsize=(9, 6))
sns.heatmap(cm, annot=True, fmt="d", linewidths=.5, cmap="mako",
            xticklabels=model.classes_, yticklabels=model.classes_, ax=ax)
ax.set(xlabel="Predicted", ylabel="True")
ax.xaxis.tick_top()
plt.yticks(rotation=0)
plt.show()

As expected, our model has excellent accuracy, but some true quasars are labelled as galaxies and a few galaxies are labelled as quasars.

Let’s check the precision and recall scores for each class.

precision = precision_score(y_test, y_pred, average=None)
recall = recall_score(y_test, y_pred, average=None)

# Print the results for each class
for class_idx in range(len(precision)):
    print(f"Class {model.classes_[class_idx]}:")
    print(f"Precision: {precision[class_idx]}")
    print(f"Recall: {recall[class_idx]}")
    print("--------------")

From the above results, it is evident that the model performs exceptionally well for Galaxy and Star classification. However, when it comes to Quasar classification, the model’s performance is comparatively less accurate.

So, if you care about galaxies and stars, this model works well; if you care about quasars, consider tuning the model’s parameters further or finding more data.
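As a shortcut, scikit-learn’s classification_report bundles precision, recall, and F1 score for every class into a single call:

from sklearn.metrics import classification_report

# If the labels were encoded with LabelEncoder (see Step 5), pass
# target_names=le.classes_ to print readable class names.
print(classification_report(y_test, y_pred))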

Conclusion

This article explained the basic theory of Decision Trees, Gradient Boosted Trees, and XGBoost. Furthermore, it walked through an implementation of an XGBoost classifier for classifying the night sky, evaluated with precision/recall scores and a confusion matrix. Other metrics, such as ROC curves and AUC, are also available but not implemented here. It’s worth noting that XGBoost also provides XGBRegressor for regression tasks.
