Upgrade Prediction API sample code to v1.6 and fix training data.
diff --git a/samples/prediction/prediction.py b/samples/prediction/prediction.py
index a130f2a..50d4b5e 100644
--- a/samples/prediction/prediction.py
+++ b/samples/prediction/prediction.py
@@ -22,7 +22,7 @@
the setup.sh script to load the sample data to Google Storage.
Usage:
- $ python prediction.py --object_name="bucket/object" --id="model_id"
+ $ python prediction.py "bucket/object" "model_id" "project_id"
You can also get help on all the command-line flags the program understands
by running:
@@ -43,8 +43,8 @@
import sys
import time
-from googleapiclient import discovery
-from googleapiclient import sample_tools
+from apiclient import discovery
+from apiclient import sample_tools
from oauth2client import client
@@ -55,9 +55,11 @@
# Declare command-line flags.
argparser = argparse.ArgumentParser(add_help=False)
argparser.add_argument('object_name',
- help='Full Google Storage path of csv data (ex bucket/object)')
-argparser.add_argument('id',
- help='Model Id of your choosing to name trained model')
+ help='Full Google Storage path of csv data (ex bucket/object)')
+argparser.add_argument('model_id',
+ help='Model Id of your choosing to name trained model')
+argparser.add_argument('project_id',
+ help='Id of the Google Cloud project to use for all Prediction API calls')
def print_header(line):
@@ -70,9 +72,15 @@
def main(argv):
+ # If you previously ran this app with an earlier version of the API
+ # or if you change the list of scopes below, revoke your app's permission
+ # here: https://accounts.google.com/IssuedAuthSubTokens
+ # Then re-run the app to re-authorize it.
service, flags = sample_tools.init(
- argv, 'prediction', 'v1.5', __doc__, __file__, parents=[argparser],
- scope='https://www.googleapis.com/auth/prediction')
+ argv, 'prediction', 'v1.6', __doc__, __file__, parents=[argparser],
+ scope=(
+ 'https://www.googleapis.com/auth/prediction',
+ 'https://www.googleapis.com/auth/devstorage.read_only'))
try:
# Get access to the Prediction API.
@@ -80,21 +88,21 @@
# List models.
print_header('Fetching list of first ten models')
- result = papi.list(maxResults=10).execute()
+ result = papi.list(maxResults=10, project=flags.project_id).execute()
print 'List results:'
pprint.pprint(result)
# Start training request on a data set.
print_header('Submitting model training request')
- body = {'id': flags.id, 'storageDataLocation': flags.object_name}
- start = papi.insert(body=body).execute()
+ body = {'id': flags.model_id, 'storageDataLocation': flags.object_name}
+ start = papi.insert(body=body, project=flags.project_id).execute()
print 'Training results:'
pprint.pprint(start)
# Wait for the training to complete.
print_header('Waiting for training to complete')
while True:
- status = papi.get(id=flags.id).execute()
+ status = papi.get(id=flags.model_id, project=flags.project_id).execute()
state = status['trainingStatus']
print 'Training state: ' + state
if state == 'DONE':
@@ -112,25 +120,27 @@
# Describe model.
print_header('Fetching model description')
- result = papi.analyze(id=flags.id).execute()
+ result = papi.analyze(id=flags.model_id, project=flags.project_id).execute()
print 'Analyze results:'
pprint.pprint(result)
- # Make a prediction using the newly trained model.
- print_header('Making a prediction')
- body = {'input': {'csvInstance': ["mucho bueno"]}}
- result = papi.predict(body=body, id=flags.id).execute()
- print 'Prediction results...'
- pprint.pprint(result)
+ # Make some predictions using the newly trained model.
+ print_header('Making some predictions')
+ for sample_text in ['mucho bueno', 'bonjour, mon cher ami']:
+ body = {'input': {'csvInstance': [sample_text]}}
+ result = papi.predict(
+ body=body, id=flags.model_id, project=flags.project_id).execute()
+ print 'Prediction results for "%s"...' % sample_text
+ pprint.pprint(result)
# Delete model.
print_header('Deleting model')
- result = papi.delete(id=flags.id).execute()
+ result = papi.delete(id=flags.model_id, project=flags.project_id).execute()
print 'Model deleted.'
except client.AccessTokenRefreshError:
- print ("The credentials have been revoked or expired, please re-run"
- "the application to re-authorize")
+ print ('The credentials have been revoked or expired, please re-run '
+ 'the application to re-authorize.')
if __name__ == '__main__':