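Update Google Prediction samples: move prediction.py to the v1.5 trainedmodels API with client_secrets.json-based OAuth, and remove the separate number and language-id samples.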
Reviewed in http://codereview.appspot.com/5696061/.
Index: samples/prediction/number.csv
===================================================================
deleted file mode 100644
diff --git a/samples/prediction/number.csv b/samples/prediction/number.csv
deleted file mode 100644
index 5f63b61..0000000
--- a/samples/prediction/number.csv
+++ /dev/null
@@ -1,4 +0,0 @@
-4, 1
-9, 2
-16, 3
-
diff --git a/samples/prediction/number.pmml b/samples/prediction/number.pmml
deleted file mode 100644
index d81e317..0000000
--- a/samples/prediction/number.pmml
+++ /dev/null
@@ -1,23 +0,0 @@
-<PMML version="4.0" xsi:schemaLocation="http://www.dmg.org/PMML-4_0 http://www.dmg.org/v4-0/pmml-4-0.xsd" xmlns="http://www.dmg.org/PMML-4_0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
- <Header copyright="Copyright (c) 2011, Google Inc. All rights reserved.">
- <Application name="Google Prediction API Sample" version="1.4"/>
- </Header>
- <DataDictionary numberOfFields="1">
- <DataField name="X" optype="continuous" dataType="double"/>
- </DataDictionary>
- <TransformationDictionary>
- <DerivedField name="Y1" dataType="double" optype="continuous">
- <Constant>1.0</Constant>
- </DerivedField>
- <DerivedField name="Y2" dataType="double" optype="continuous">
- <FieldRef field="X"/>
- </DerivedField>
- <DerivedField name="Y3" dataType="double" optype="continuous">
- <Apply function="pow">
- <FieldRef field="X"/>
- <Constant>2.0</Constant>
- </Apply>
- </DerivedField>
- </TransformationDictionary>
-</PMML>
-
diff --git a/samples/prediction/prediction.py b/samples/prediction/prediction.py
index 1227019..958d79c 100644
--- a/samples/prediction/prediction.py
+++ b/samples/prediction/prediction.py
@@ -22,7 +22,7 @@
the setup.sh script to load the sample data to Google Storage.
Usage:
- $ python prediction.py --object_name="bucket/object"
+ $ python prediction.py --object_name="bucket/object" --id="model_id"
You can also get help on all the command-line flags the program understands
by running:
@@ -34,53 +34,81 @@
$ python prediction.py --logging_level=DEBUG
"""
-__author__ = 'jcgregorio@google.com (Joe Gregorio)'
+__author__ = ('jcgregorio@google.com (Joe Gregorio), '
+ 'marccohen@google.com (Marc Cohen)')
import apiclient.errors
import gflags
import httplib2
import logging
+import os
import pprint
import sys
+import time
from apiclient.discovery import build
from oauth2client.file import Storage
from oauth2client.client import AccessTokenRefreshError
-from oauth2client.client import OAuth2WebServerFlow
+from oauth2client.client import flow_from_clientsecrets
from oauth2client.tools import run
FLAGS = gflags.FLAGS
-# Set up a Flow object to be used if we need to authenticate. This
-# sample uses OAuth 2.0, and we set up the OAuth2WebServerFlow with
-# the information it needs to authenticate. Note that it is called
-# the Web Server Flow, but it can also handle the flow for native
-# applications <http://code.google.com/apis/accounts/docs/OAuth2.html#IA>
-# The client_id client_secret are copied from the API Access tab on
-# the Google APIs Console <http://code.google.com/apis/console>. When
-# creating credentials for this application be sure to choose an Application
-# type of "Installed application".
-FLOW = OAuth2WebServerFlow(
- client_id='433807057907.apps.googleusercontent.com',
- client_secret='jigtZpMApkRxncxikFpR+SFg',
- scope='https://www.googleapis.com/auth/prediction',
- user_agent='prediction-cmdline-sample/1.0')
+# CLIENT_SECRETS, name of a file containing the OAuth 2.0 information for this
+# application, including client_id and client_secret, which are found
+# on the API Access tab on the Google APIs
+# Console <http://code.google.com/apis/console>
+CLIENT_SECRETS = 'client_secrets.json'
+
+# Helpful message to display in the browser if the CLIENT_SECRETS file
+# is missing.
+MISSING_CLIENT_SECRETS_MESSAGE = """
+WARNING: Please configure OAuth 2.0
+
+To make this sample run you will need to populate the client_secrets.json file
+found at:
+
+ %s
+
+with information from the APIs Console <https://code.google.com/apis/console>.
+
+""" % os.path.join(os.path.dirname(__file__), CLIENT_SECRETS)
+
+# Set up a Flow object to be used if we need to authenticate.
+FLOW = flow_from_clientsecrets(CLIENT_SECRETS,
+ scope='https://www.googleapis.com/auth/prediction',
+ message=MISSING_CLIENT_SECRETS_MESSAGE)
# The gflags module makes defining command-line options easy for
# applications. Run this program with the '--help' argument to see
# all the flags that it understands.
gflags.DEFINE_enum('logging_level', 'ERROR',
- ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
- 'Set the level of logging detail.')
+ ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
+ 'Set the level of logging detail.')
gflags.DEFINE_string('object_name',
None,
'Full Google Storage path of csv data (ex bucket/object)')
-
gflags.MarkFlagAsRequired('object_name')
+gflags.DEFINE_string('id',
+ None,
+ 'Model Id of your choosing to name trained model')
+gflags.MarkFlagAsRequired('id')
+
+# Time to wait (in seconds) between successive checks of training status.
+SLEEP_TIME = 10
+
+def print_header(line):
+ """Print a blank line, then line framed above and below by '=' rules of the same length."""
+ header_str = '='
+ header_line = header_str * len(line)
+ print '\n' + header_line
+ print line
+ print header_line
+
def main(argv):
- # Let the gflags module process the command-line arguments
+ # Let the gflags module process the command-line arguments.
try:
argv = FLAGS(argv)
except gflags.FlagsError, e:
@@ -103,40 +131,61 @@
http = httplib2.Http()
http = credentials.authorize(http)
- service = build("prediction", "v1.3", http=http)
-
try:
- # Start training on a data set
- train = service.training()
- body = {'id': FLAGS.object_name}
- start = train.insert(body=body).execute()
+ # Get access to the Prediction API.
+ service = build("prediction", "v1.5", http=http)
+ papi = service.trainedmodels()
- print 'Started training'
+ # List models.
+ print_header('Fetching list of first ten models')
+ result = papi.list(maxResults=10).execute()
+ print 'List results:'
+ pprint.pprint(result)
+
+ # Start training request on a data set.
+ print_header('Submitting model training request')
+ body = {'id': FLAGS.id, 'storageDataLocation': FLAGS.object_name}
+ start = papi.insert(body=body).execute()
+ print 'Training results:'
pprint.pprint(start)
-
- import time
- # Wait for the training to complete
+
+ # Wait for the training to complete.
+ print_header('Waiting for training to complete')
while True:
- try:
- # We check the training job is completed. If it is not it will return an error code.
- status = train.get(data=FLAGS.object_name).execute()
- # Job has completed.
- pprint.pprint(status)
+ status = papi.get(id=FLAGS.id).execute()
+ state = status['trainingStatus']
+ print 'Training state: ' + state
+ if state == 'DONE':
break
- except apiclient.errors.HttpError as error:
- # Training job not yet completed.
- print 'Waiting for training to complete.'
- time.sleep(10)
+ elif state == 'RUNNING':
+ time.sleep(SLEEP_TIME)
+ continue
+ else:
+ raise Exception('Training Error: ' + state)
+
+ # Job has completed. NOTE(review): unreachable as written -- every branch above breaks, continues, or raises.
+ print 'Training completed:'
+ pprint.pprint(status)
+ break
- print 'Training is complete'
+ # Describe model.
+ print_header('Fetching model description')
+ result = papi.analyze(id=FLAGS.id).execute()
+ print 'Analyze results:'
+ pprint.pprint(result)
- # Now make a prediction using that training
+ # Make a prediction using the newly trained model.
+ print_header('Making a prediction')
body = {'input': {'csvInstance': ["mucho bueno"]}}
- prediction = train.predict(body=body, data=FLAGS.object_name).execute()
- print 'The prediction is:'
- pprint.pprint(prediction)
+ result = papi.predict(body=body, id=FLAGS.id).execute()
+ print 'Prediction results...'
+ pprint.pprint(result)
+ # Delete model.
+ print_header('Deleting model')
+ result = papi.delete(id=FLAGS.id).execute()
+ print 'Model deleted.'
except AccessTokenRefreshError:
print ("The credentials have been revoked or expired, please re-run"
diff --git a/samples/prediction/prediction_language_id.py b/samples/prediction/prediction_language_id.py
deleted file mode 100644
index 9657b3c..0000000
--- a/samples/prediction/prediction_language_id.py
+++ /dev/null
@@ -1,167 +0,0 @@
-#!/usr/bin/python2.4
-# -*- coding: utf-8 -*-
-#
-# Copyright (C) 2010 Google Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Simple command-line sample for the Google Prediction API
-
-Command-line application that trains on your input data. This sample does
-the same thing as the Hello Prediction! example. You might want to run
-the setup.sh script to load the sample data to Google Storage.
-
-Usage:
- $ python prediction_language_id.py --model_id="foo"
- --data_file="bucket/object"
-
-You can also get help on all the command-line flags the program understands
-by running:
-
- $ python prediction_language_id.py --help
-
-To get detailed log output run:
-
- $ python prediction_language_id.py --logging_level=DEBUG
-"""
-
-__author__ = 'jcgregorio@google.com (Joe Gregorio)'
-
-from apiclient.discovery import build_from_document
-
-import apiclient.errors
-import gflags
-import httplib2
-import logging
-import os
-import pprint
-import sys
-
-from apiclient.discovery import build
-from oauth2client.file import Storage
-from oauth2client.client import AccessTokenRefreshError
-from oauth2client.client import flow_from_clientsecrets
-from oauth2client.tools import run
-
-FLAGS = gflags.FLAGS
-
-# CLIENT_SECRETS, name of a file containing the OAuth 2.0 information for this
-# application, including client_id and client_secret, which are found
-# on the API Access tab on the Google APIs
-# Console <http://code.google.com/apis/console>
-CLIENT_SECRETS = 'client_secrets.json'
-
-# Helpful message to display in the browser if the CLIENT_SECRETS file
-# is missing.
-MISSING_CLIENT_SECRETS_MESSAGE = """
-WARNING: Please configure OAuth 2.0
-
-To make this sample run you will need to populate the client_secrets.json file
-found at:
-
- %s
-
-with information from the APIs Console <https://code.google.com/apis/console>.
-
-""" % os.path.join(os.path.dirname(__file__), CLIENT_SECRETS)
-
-# Set up a Flow object to be used if we need to authenticate.
-FLOW = flow_from_clientsecrets(CLIENT_SECRETS,
- scope='https://www.googleapis.com/auth/prediction',
- message=MISSING_CLIENT_SECRETS_MESSAGE)
-
-# The gflags module makes defining command-line options easy for
-# applications. Run this program with the '--help' argument to see
-# all the flags that it understands.
-gflags.DEFINE_enum('logging_level', 'ERROR',
- ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
- 'Set the level of logging detail.')
-
-gflags.DEFINE_string('model_id',
- None,
- 'The unique name for the predictive model (ex foo)')
-
-gflags.DEFINE_string('data_file',
- None,
- 'Full Google Storage path of csv data (ex bucket/object)')
-
-gflags.MarkFlagAsRequired('model_id')
-gflags.MarkFlagAsRequired('data_file')
-
-def main(argv):
- # Let the gflags module process the command-line arguments
- try:
- argv = FLAGS(argv)
- except gflags.FlagsError, e:
- print '%s\\nUsage: %s ARGS\\n%s' % (e, argv[0], FLAGS)
- sys.exit(1)
-
- # Set the logging according to the command-line flag
- logging.getLogger().setLevel(getattr(logging, FLAGS.logging_level))
-
- # If the Credentials don't exist or are invalid run through the native client
- # flow. The Storage object will ensure that if successful the good
- # Credentials will get written back to a file.
- storage = Storage('prediction.dat')
- credentials = storage.get()
- if credentials is None or credentials.invalid:
- credentials = run(FLOW, storage)
-
- # Create an httplib2.Http object to handle our HTTP requests and authorize it
- # with our good Credentials.
- http = httplib2.Http()
- http = credentials.authorize(http)
-
- service = build("prediction", "v1.4", http=http)
-
- try:
-
- # Start training on a data set
- train = service.trainedmodels()
- body = {'id': FLAGS.model_id, 'storageDataLocation': FLAGS.data_file}
- start = train.insert(body=body).execute()
-
- print 'Started training'
- pprint.pprint(start)
-
- import time
- # Wait for the training to complete
- while True:
- try:
- # We check the training job is completed. If it is not it will return
- # an error code.
- status = train.get(id=FLAGS.model_id).execute()
- # Job has completed.
- pprint.pprint(status)
- break
- except apiclient.errors.HttpError as error:
- # Training job not yet completed.
- print 'Waiting for training to complete.'
- time.sleep(10)
-
- print 'Training is complete'
-
- # Now make a prediction using that training
- body = {'input': {'csvInstance': ["mucho bueno"]}}
- prediction = train.predict(body=body, id=FLAGS.model_id).execute()
- print 'The prediction is:'
- pprint.pprint(prediction)
-
-
- except AccessTokenRefreshError:
- print ("The credentials have been revoked or expired, please re-run"
- "the application to re-authorize")
-
-if __name__ == '__main__':
- main(sys.argv)
-
diff --git a/samples/prediction/prediction_number.py b/samples/prediction/prediction_number.py
deleted file mode 100644
index 536d6c8..0000000
--- a/samples/prediction/prediction_number.py
+++ /dev/null
@@ -1,175 +0,0 @@
-#!/usr/bin/python2.4
-# -*- coding: utf-8 -*-
-#
-# Copyright (C) 2010 Google Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Simple command-line sample for the Google Prediction API
-
-Command-line application that trains on your input data. This sample does
-the same thing as the Hello Prediction! example. You might want to run
-the setup.sh script to load both the sample data and the pmml file to
-Google Storage.
-
-Usage:
- $ python prediction_number.py --model_id="foo"
- --data_file="data_bucket/data_object" --pmml_file="pmml_bucket/pmml_object"
-
-You can also get help on all the command-line flags the program understands
-by running:
-
- $ python prediction_number.py --help
-
-To get detailed log output run:
-
- $ python prediction_number.py --logging_level=DEBUG
-"""
-
-__author__ = 'jcgregorio@google.com (Joe Gregorio)'
-
-from apiclient.discovery import build_from_document
-
-import apiclient.errors
-import gflags
-import httplib2
-import logging
-import os
-import pprint
-import sys
-
-from apiclient.discovery import build
-from oauth2client.file import Storage
-from oauth2client.client import AccessTokenRefreshError
-from oauth2client.client import flow_from_clientsecrets
-from oauth2client.tools import run
-
-FLAGS = gflags.FLAGS
-
-# CLIENT_SECRETS, name of a file containing the OAuth 2.0 information for this
-# application, including client_id and client_secret, which are found
-# on the API Access tab on the Google APIs
-# Console <http://code.google.com/apis/console>
-CLIENT_SECRETS = 'client_secrets.json'
-
-# Helpful message to display in the browser if the CLIENT_SECRETS file
-# is missing.
-MISSING_CLIENT_SECRETS_MESSAGE = """
-WARNING: Please configure OAuth 2.0
-
-To make this sample run you will need to populate the client_secrets.json file
-found at:
-
- %s
-
-with information from the APIs Console <https://code.google.com/apis/console>.
-
-""" % os.path.join(os.path.dirname(__file__), CLIENT_SECRETS)
-
-# Set up a Flow object to be used if we need to authenticate.
-FLOW = flow_from_clientsecrets(CLIENT_SECRETS,
- scope='https://www.googleapis.com/auth/prediction',
- message=MISSING_CLIENT_SECRETS_MESSAGE)
-
-# The gflags module makes defining command-line options easy for
-# applications. Run this program with the '--help' argument to see
-# all the flags that it understands.
-gflags.DEFINE_enum('logging_level', 'ERROR',
- ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
- 'Set the level of logging detail.')
-
-gflags.DEFINE_string('model_id',
- None,
- 'The unique name for the predictive model (ex foo)')
-
-gflags.DEFINE_string('data_file',
- None,
- 'Full Google Storage path of csv data (ex bucket/object)')
-
-gflags.DEFINE_string('pmml_file',
- None,
- 'Full Google Storage path of pmml for '
- 'preprocessing (ex bucket/object)')
-
-gflags.MarkFlagAsRequired('model_id')
-gflags.MarkFlagAsRequired('data_file')
-gflags.MarkFlagAsRequired('pmml_file')
-
-def main(argv):
- # Let the gflags module process the command-line arguments
- try:
- argv = FLAGS(argv)
- except gflags.FlagsError, e:
- print '%s\\nUsage: %s ARGS\\n%s' % (e, argv[0], FLAGS)
- sys.exit(1)
-
- # Set the logging according to the command-line flag
- logging.getLogger().setLevel(getattr(logging, FLAGS.logging_level))
-
- # If the Credentials don't exist or are invalid run through the native client
- # flow. The Storage object will ensure that if successful the good
- # Credentials will get written back to a file.
- storage = Storage('prediction.dat')
- credentials = storage.get()
- if credentials is None or credentials.invalid:
- credentials = run(FLOW, storage)
-
- # Create an httplib2.Http object to handle our HTTP requests and authorize it
- # with our good Credentials.
- http = httplib2.Http()
- http = credentials.authorize(http)
-
- service = build("prediction", "v1.4", http=http)
-
- try:
-
- # Start training on a data set
- train = service.trainedmodels()
- body = {'id': FLAGS.model_id, 'storageDataLocation': FLAGS.data_file,
- 'storagePMMLLocation': FLAGS.pmml_file}
- start = train.insert(body=body).execute()
-
- print 'Started training'
- pprint.pprint(start)
-
- import time
- # Wait for the training to complete
- while True:
- try:
- # We check the training job is completed. If it is not it will return
- # an error code.
- status = train.get(id=FLAGS.model_id).execute()
- # Job has completed.
- pprint.pprint(status)
- break
- except apiclient.errors.HttpError as error:
- # Training job not yet completed.
- print 'Waiting for training to complete.'
- time.sleep(10)
-
- print 'Training is complete'
-
- # Now make a prediction using that training
- body = {'input': {'csvInstance': [ 5 ]}}
- prediction = train.predict(body=body, id=FLAGS.model_id).execute()
- print 'The prediction is:'
- pprint.pprint(prediction)
-
-
- except AccessTokenRefreshError:
- print ("The credentials have been revoked or expired, please re-run"
- "the application to re-authorize")
-
-if __name__ == '__main__':
- main(sys.argv)
-