Introduction

In this introduction, we’ll provide a step-by-step guide to training models with AWS Sagemaker using the sagemaker R package.

We are going to train and tune an xgboost regression model on the sagemaker::abalone dataset, analyze the hyperparameters, and make new predictions.

Tuning

The tuning interface is similar to the one in the caret package. We’ll

  1. choose a model

  2. define a hyperparameter grid

  3. set the training and validation data

Dataset

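We’ll build a regression model on the abalone dataset bundled with the package, which comes from the UCI Machine Learning Repository. The goal is to predict an abalone’s number of shell rings, a proxy for its age.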

AWS Sagemaker’s built-in hyperparameter tuning requires a train/validation split; cross-validation is not supported out of the box.

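We can quickly split the data with rsample, e.g. abalone_split <- initial_split(sagemaker::abalone). By default, initial_split() puts a random 3/4 of the rows in the analysis (training) set and the remaining 1/4 in the assessment (validation) set.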
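The split itself is simple enough to sketch in base R. The helper split_rows() below is hypothetical and is not part of rsample or sagemaker. It keeps a fraction prop of the rows for training (3/4, matching initial_split()'s default) and the rest for validation. The built-in mtcars data stands in for sagemaker::abalone so the snippet runs without the package.

```r
# Hypothetical helper (not part of rsample or sagemaker): a base-R sketch of a
# random train/validation split, similar in spirit to rsample::initial_split().
split_rows <- function(data, prop = 3 / 4, seed = NULL) {
  if (!is.null(seed)) set.seed(seed)
  n_train <- floor(nrow(data) * prop)
  # Sample row indices for the training set without replacement
  train_idx <- sample.int(nrow(data), n_train)
  list(
    train = data[train_idx, , drop = FALSE],
    validation = data[-train_idx, , drop = FALSE]
  )
}

# mtcars (32 rows) stands in for sagemaker::abalone
parts <- split_rows(mtcars, seed = 42)
nrow(parts$train)       # 24
nrow(parts$validation)  # 8
```

Every row lands in exactly one of the two sets, which is what the tuner needs from a hold-out split.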

The training data must be uploaded to an S3 bucket that AWS Sagemaker can read from and write to. For a role with the standard AmazonSageMakerFullAccess policy, that includes any bucket with sagemaker in its name.

We’ll use the sagemaker::write_s3 helper to upload tibbles or data.frames to S3 as a csv.

write_s3(analysis(abalone_split), s3(s3_bucket(), "abalone-train.csv"))
write_s3(assessment(abalone_split), s3(s3_bucket(), "abalone-test.csv"))
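SageMaker's built-in XGBoost algorithm expects CSV training data with the target in the first column and no header row. We assume write_s3 produces files in that layout. The sketch below shows the layout itself with a hypothetical helper, write_xgb_csv(), which is not the package's implementation. It writes to a local temporary file, and two rows of mtcars play the role of the data, with wt as a stand-in target.

```r
# Hypothetical helper (not the sagemaker implementation): write a data frame in
# the CSV layout SageMaker's built-in XGBoost expects for training, i.e.
# target in the first column and no header row.
write_xgb_csv <- function(data, target, path) {
  # Move the target column to the front, keep the other columns in order
  ordered <- data[, c(target, setdiff(names(data), target)), drop = FALSE]
  write.table(ordered, path, sep = ",", row.names = FALSE, col.names = FALSE)
  invisible(path)
}

# Two rows of mtcars stand in for the abalone data; "wt" plays the target
toy <- mtcars[1:2, c("mpg", "cyl", "wt")]
path <- write_xgb_csv(toy, target = "wt", path = tempfile(fileext = ".csv"))
readLines(path)
# "2.62,21,6" "2.875,21,6"
```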

To make sagemaker::s3_bucket() return a particular bucket, set options(sagemaker.default.bucket = "bucket_name").
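The option is ordinary R global state, read back with getOption(). The sketch below shows how such a default is typically consumed. We assume, but have not verified, that sagemaker::s3() builds URIs of the form s3://bucket/key. The helper s3_uri() and the bucket name are hypothetical.

```r
# Set the default bucket ("my-sagemaker-bucket" is a placeholder name)
options(sagemaker.default.bucket = "my-sagemaker-bucket")

# Hypothetical helper, not the package's s3(): build an "s3://bucket/key" URI,
# falling back to the option when no bucket is passed. The default argument is
# evaluated lazily, so it picks up the option's value at call time.
s3_uri <- function(key, bucket = getOption("sagemaker.default.bucket")) {
  if (is.null(bucket)) stop("no bucket given and sagemaker.default.bucket is unset")
  paste0("s3://", bucket, "/", key)
}

s3_uri("abalone-train.csv")
# "s3://my-sagemaker-bucket/abalone-train.csv"
```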

Then we’ll save the paths to use in tuning:

split <- s3_split(
  s3_train = s3(s3_bucket(), "abalone-train.csv"),
  s3_validation = s3(s3_bucket(), "abalone-test.csv")
)

Hyperparameters

Now we’ll define ranges to tune over:
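SageMaker's tuning API describes a search space as integer, continuous, or categorical ranges rather than a fixed grid. The tuner uses Bayesian optimization by default, with random search as an alternative strategy. The sketch below uses plain R lists, not the sagemaker package's range constructors, and illustrative bounds. It only shows what sampling candidate configurations from such ranges means, using random search.

```r
# Illustrative search space as plain R lists (not the sagemaker package's API).
# max_depth, eta and subsample are XGBoost hyperparameters; bounds are examples.
ranges <- list(
  max_depth = list(type = "integer", min = 3, max = 10),
  eta       = list(type = "continuous", min = 0.01, max = 0.3),
  subsample = list(type = "continuous", min = 0.5, max = 1)
)

# Draw one candidate configuration uniformly at random from each range
sample_config <- function(ranges) {
  lapply(ranges, function(r) {
    if (r$type == "integer") {
      # Uniform over the integers min..max (inclusive)
      r$min + sample.int(r$max - r$min + 1, 1) - 1
    } else {
      # Uniform over the interval [min, max]
      runif(1, r$min, r$max)
    }
  })
}

set.seed(1)
candidates <- replicate(5, sample_config(ranges), simplify = FALSE)
str(candidates[[1]])
```

Each candidate corresponds to one training job. A Bayesian tuner differs only in how it picks the next candidate, using the results of earlier jobs.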

Analysis

Training

We can also inspect the logs of an individual training job to compare the training and validation error across iterations. This can be useful for more advanced model tuning.

Note that tune$model_name is the name of the best model found by the tuning job.

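library(tidyr)    # pivot_longer(), %>%
library(ggplot2)

# job_logs: one row per boosting iteration, with train and validation RMSE
job_logs %>%
  pivot_longer(`train:rmse`:`validation:rmse`, names_to = "metric", values_to = "rmse") %>%
  ggplot(aes(iteration, rmse, color = metric)) +
  geom_line()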
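A common use of these curves is to find where the validation error stops improving while the training error keeps falling, a sign of overfitting. The base-R sketch below finds that iteration and the train/validation gap there. It uses synthetic curves in a data frame called toy_logs. The column names mirror those in the plot code, but the values are invented for illustration and are not SageMaker output.

```r
# Synthetic stand-in for job_logs (NOT real SageMaker output): training error
# keeps falling, while validation error bottoms out at iteration 30.
iteration <- 1:50
toy_logs <- data.frame(
  iteration = iteration,
  `train:rmse` = 1 + 2 * exp(-iteration / 10),
  `validation:rmse` = 1.5 + (iteration - 30)^2 / 1000,
  check.names = FALSE
)

# Iteration with the lowest validation error
best <- toy_logs$iteration[which.min(toy_logs[["validation:rmse"]])]

# Generalization gap (validation minus train) at that iteration
gap <- with(toy_logs[toy_logs$iteration == best, ], `validation:rmse` - `train:rmse`)

best  # 30
gap   # about 0.40
```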

Predictions

The AWS Sagemaker API supports two prediction modes: real-time endpoints and batch inference.

Real-time

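Real-time inference deploys the model to a persistent web endpoint. Deployment takes a few minutes.

Then make new predictions on tibbles or data.frames with the standard predict() generic. The input should contain only feature columns, so we drop the first column, which holds the target: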

pred <- predict(tune, sagemaker::abalone[1:100, -1])
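An endpoint serving the built-in XGBoost algorithm accepts text/csv request bodies: one row per line, features only, no header. We assume predict() serializes the data frame roughly like this before calling the endpoint, but this is a base-R sketch, not the package's code. The small data frame holds toy values standing in for sagemaker::abalone.

```r
# Toy values standing in for sagemaker::abalone: the target (rings) comes first
toy <- data.frame(
  rings = c(15, 7),
  length = c(0.455, 0.35),
  diameter = c(0.365, 0.265)
)

# Drop the target column, as in abalone[1:100, -1], then serialize the
# remaining features as header-less CSV, one row per line
features <- toy[, -1, drop = FALSE]
csv_lines <- capture.output(
  write.table(features, sep = ",", row.names = FALSE, col.names = FALSE)
)
body <- paste(csv_lines, collapse = "\n")
cat(body)
```

If the target column were left in, the endpoint would treat it as an extra feature, and the row width would no longer match what the model was trained on.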

Once deployed, the endpoint typically responds with subsecond latency.

The endpoint is billed for as long as it is running, so make sure to delete it when you are done.

Batch

You can also make batch predictions on data stored in S3. batch_predict() runs a SageMaker batch transform job on the input CSV, writes the predictions as CSV under the s3_output prefix, and returns the output path.

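# The inference file holds features only (no target column)
write_s3(sagemaker::abalone[1:100, -1], s3(s3_bucket(), "abalone-inference.csv"))

s3_output_path <- batch_predict(
  tune,
  s3_input = s3(s3_bucket(), "abalone-inference.csv"),
  s3_output = s3(s3_bucket(), "abalone_predictions")
)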

The sagemaker::read_s3() helper reads CSV data from S3, so the predictions can be loaded directly from s3_output_path.
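Batch transform typically writes one output file per input file, named after the input with an .out suffix. For XGBoost regression, that file holds one prediction per line with no header. Whatever helper fetches the file, parsing it amounts to reading a header-less CSV and naming its column. The base-R sketch below does this on a local temporary file. The values are placeholders, and the column name predicted_rings is our own choice.

```r
# Simulate a downloaded batch-transform output file: one prediction per line,
# no header (values are placeholders, not real model output)
out_file <- tempfile(fileext = ".csv.out")
writeLines(c("9.51", "7.88", "10.2"), out_file)

# Read it back as a one-column data frame with a descriptive name
preds <- read.csv(out_file, header = FALSE, col.names = "predicted_rings")
str(preds)
```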