From 08239a2181993426412cc30fa09ef47438d88f55 Mon Sep 17 00:00:00 2001 From: Xiaofei Wang <xiaofei@student.unimelb.edu.au> Date: Wed, 11 Mar 2020 21:04:04 +1100 Subject: [PATCH] =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E6=96=B0=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AZUL Learn Opponent/learn_opponent.py | 48 +++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 AZUL Learn Opponent/learn_opponent.py diff --git a/AZUL Learn Opponent/learn_opponent.py b/AZUL Learn Opponent/learn_opponent.py new file mode 100644 index 0000000..e3ff9df --- /dev/null +++ b/AZUL Learn Opponent/learn_opponent.py @@ -0,0 +1,48 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import random + +from tensorflow import keras + +# Helper libraries +import numpy as np +import matplotlib.pyplot as plt + + +class Net_model: + + + def train(self, x, y): + model = keras.models.Sequential([ + keras.layers.Flatten(input_shape=(150, 6)), + keras.layers.Dense(1024, activation='relu'), + # keras.layers.Dropout(0.2), + keras.layers.Dense(150, activation='softmax') + ]) + + model.compile(optimizer='adam', + loss='sparse_categorical_crossentropy', + metrics=['accuracy']) + train_num = len(x)//5*4 + print('tn', train_num) + DATA_X = np.array(x) + DATA_Y = np.array(y) + + TR_X, TE_X = DATA_X[:train_num], DATA_X[train_num:] + TR_Y, TE_Y = DATA_Y[:train_num], DATA_Y[train_num:] + + cost = model.fit(TR_X, TR_Y, epochs=50) + print(cost) + cost = model.evaluate(TE_X, TE_Y) + print(cost) + Y_pred = model.predict(TE_X) + + # print(TE_X[0]) + # print(np.argmax(Y_pred[0])) + model.save('naive_model.h5') + + def perdict(self, X): + model = keras.models.load_model('naive_model.h5') + Y_pred = model.predict(X) + return Y_pred + -- GitLab