Commit bd367b56 authored by Alejandro Riera's avatar Alejandro Riera

coursera clone

parent ca9c3860
import numpy as np
import h5py
from os import path
def load_dataset():
train_path = path.join(path.dirname(path.realpath(__file__)), 'datasets/train_catvnoncat.h5')
train_dataset = h5py.File(train_path, "r")
train_set_x_orig = np.array(train_dataset["train_set_x"][:]) # your train set features
train_set_y_orig = np.array(train_dataset["train_set_y"][:]) # your train set labels
test_path = path.join(path.dirname(path.realpath(__file__)), 'datasets/test_catvnoncat.h5')
test_dataset = h5py.File(test_path, "r")
test_set_x_orig = np.array(test_dataset["test_set_x"][:]) # your test set features
test_set_y_orig = np.array(test_dataset["test_set_y"][:]) # your test set labels
classes = np.array(test_dataset["list_classes"][:]) # the list of classes
train_set_y_orig = train_set_y_orig.reshape((1, train_set_y_orig.shape[0]))
test_set_y_orig = test_set_y_orig.reshape((1, test_set_y_orig.shape[0]))
return train_set_x_orig, train_set_y_orig, test_set_x_orig, test_set_y_orig, classes
torch
torchvision
numpy
matplotlib
h5py
scipy
# PIL
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment