Update Google Prediction samples.
Reviewed in http://codereview.appspot.com/5696061/.
Index: samples/prediction/number.csv
===================================================================
deleted file mode 100644
diff --git a/samples/prediction/prediction.py b/samples/prediction/prediction.py
index 1227019..958d79c 100644
--- a/samples/prediction/prediction.py
+++ b/samples/prediction/prediction.py
@@ -22,7 +22,7 @@
the setup.sh script to load the sample data to Google Storage.
Usage:
- $ python prediction.py --object_name="bucket/object"
+ $ python prediction.py --object_name="bucket/object" --id="model_id"
You can also get help on all the command-line flags the program understands
by running:
@@ -34,53 +34,81 @@
$ python prediction.py --logging_level=DEBUG
"""
-__author__ = 'jcgregorio@google.com (Joe Gregorio)'
+__author__ = ('jcgregorio@google.com (Joe Gregorio), '
+ 'marccohen@google.com (Marc Cohen)')
import apiclient.errors
import gflags
import httplib2
import logging
+import os
import pprint
import sys
+import time
from apiclient.discovery import build
from oauth2client.file import Storage
from oauth2client.client import AccessTokenRefreshError
-from oauth2client.client import OAuth2WebServerFlow
+from oauth2client.client import flow_from_clientsecrets
from oauth2client.tools import run
FLAGS = gflags.FLAGS
-# Set up a Flow object to be used if we need to authenticate. This
-# sample uses OAuth 2.0, and we set up the OAuth2WebServerFlow with
-# the information it needs to authenticate. Note that it is called
-# the Web Server Flow, but it can also handle the flow for native
-# applications <http://code.google.com/apis/accounts/docs/OAuth2.html#IA>
-# The client_id client_secret are copied from the API Access tab on
-# the Google APIs Console <http://code.google.com/apis/console>. When
-# creating credentials for this application be sure to choose an Application
-# type of "Installed application".
-FLOW = OAuth2WebServerFlow(
- client_id='433807057907.apps.googleusercontent.com',
- client_secret='jigtZpMApkRxncxikFpR+SFg',
- scope='https://www.googleapis.com/auth/prediction',
- user_agent='prediction-cmdline-sample/1.0')
+# CLIENT_SECRETS is the path of a file holding this application's OAuth 2.0
+# information (client_id and client_secret, from the API Access tab of the
+# Google APIs Console <http://code.google.com/apis/console>). It is anchored
+# to this script's directory so it matches the path named in the warning below.
+CLIENT_SECRETS = os.path.join(os.path.dirname(__file__), 'client_secrets.json')
+
+# Message passed to flow_from_clientsecrets() below, to be reported if the
+# CLIENT_SECRETS file is missing.
+MISSING_CLIENT_SECRETS_MESSAGE = """
+WARNING: Please configure OAuth 2.0
+
+To make this sample run you will need to populate the client_secrets.json file
+found at:
+
+ %s
+
+with information from the APIs Console <https://code.google.com/apis/console>.
+
+""" % CLIENT_SECRETS
+
+# Set up a Flow object to be used if we need to authenticate.
+FLOW = flow_from_clientsecrets(CLIENT_SECRETS,
+ scope='https://www.googleapis.com/auth/prediction',
+ message=MISSING_CLIENT_SECRETS_MESSAGE)
# The gflags module makes defining command-line options easy for
# applications. Run this program with the '--help' argument to see
# all the flags that it understands.
gflags.DEFINE_enum('logging_level', 'ERROR',
- ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
- 'Set the level of logging detail.')
+ ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
+ 'Set the level of logging detail.')
gflags.DEFINE_string('object_name',
None,
'Full Google Storage path of csv data (ex bucket/object)')
-
gflags.MarkFlagAsRequired('object_name')
+gflags.DEFINE_string('id',
+ None,
+ 'Model Id of your choosing to name trained model')
+gflags.MarkFlagAsRequired('id')
+
+# Seconds to sleep between successive polls of the model's trainingStatus.
+SLEEP_TIME = 10
+
+def print_header(line):
+ '''Print line between two '=' rules of len(line), after a blank line.'''
+ header_str = '='
+ header_line = header_str * len(line)
+ print '\n' + header_line
+ print line
+ print header_line
+
def main(argv):
- # Let the gflags module process the command-line arguments
+ # Let the gflags module process the command-line arguments.
try:
argv = FLAGS(argv)
except gflags.FlagsError, e:
@@ -103,40 +131,61 @@
http = httplib2.Http()
http = credentials.authorize(http)
- service = build("prediction", "v1.3", http=http)
-
try:
- # Start training on a data set
- train = service.training()
- body = {'id': FLAGS.object_name}
- start = train.insert(body=body).execute()
+ # Get access to the Prediction API.
+ service = build("prediction", "v1.5", http=http)
+ papi = service.trainedmodels()
- print 'Started training'
+ # List models.
+ print_header('Fetching list of first ten models')
+ result = papi.list(maxResults=10).execute()
+ print 'List results:'
+ pprint.pprint(result)
+
+ # Start training request on a data set.
+ print_header('Submitting model training request')
+ body = {'id': FLAGS.id, 'storageDataLocation': FLAGS.object_name}
+ start = papi.insert(body=body).execute()
+ print 'Training results:'
pprint.pprint(start)
-
- import time
- # Wait for the training to complete
+
+ # Wait for the training to complete.
+ print_header('Waiting for training to complete')
while True:
- try:
- # We check the training job is completed. If it is not it will return an error code.
- status = train.get(data=FLAGS.object_name).execute()
- # Job has completed.
- pprint.pprint(status)
+ status = papi.get(id=FLAGS.id).execute()
+ state = status['trainingStatus']
+ print 'Training state: ' + state
+ if state == 'DONE':
break
- except apiclient.errors.HttpError as error:
- # Training job not yet completed.
- print 'Waiting for training to complete.'
- time.sleep(10)
+ elif state == 'RUNNING':
+ time.sleep(SLEEP_TIME)
+ continue
+ else:
+ raise Exception('Training Error: ' + state)
+
+ # Job has completed.
+ print 'Training completed:'
+ pprint.pprint(status)
+ break
- print 'Training is complete'
+ # Describe model.
+ print_header('Fetching model description')
+ result = papi.analyze(id=FLAGS.id).execute()
+ print 'Analyze results:'
+ pprint.pprint(result)
- # Now make a prediction using that training
+ # Make a prediction using the newly trained model.
+ print_header('Making a prediction')
body = {'input': {'csvInstance': ["mucho bueno"]}}
- prediction = train.predict(body=body, data=FLAGS.object_name).execute()
- print 'The prediction is:'
- pprint.pprint(prediction)
+ result = papi.predict(body=body, id=FLAGS.id).execute()
+ print 'Prediction results...'
+ pprint.pprint(result)
+ # Delete model.
+ print_header('Deleting model')
+ result = papi.delete(id=FLAGS.id).execute()
+ print 'Model deleted.'
except AccessTokenRefreshError:
print ("The credentials have been revoked or expired, please re-run"