| #!/usr/bin/python2.4 |
| # -*- coding: utf-8 -*- |
| # |
| # Copyright (C) 2010 Google Inc. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """Simple command-line sample for the Google Prediction API |
| |
| Command-line application that trains a model on your input data and then |
| makes a prediction with it. This sample does the same thing as the Hello |
| Prediction! example. You might want to run the setup.sh script first to |
| load the sample data into Google Storage. |
| |
| Usage: |
| $ python prediction.py --object_name="bucket/object" |
| |
| You can also get help on all the command-line flags the program understands |
| by running: |
| |
| $ python prediction.py --help |
| |
| To get detailed log output run: |
| |
| $ python prediction.py --object_name="bucket/object" --logging_level=DEBUG |
| """ |
| |
| __author__ = 'jcgregorio@google.com (Joe Gregorio)' |
| |
| import apiclient.errors |
| import gflags |
| import httplib2 |
| import logging |
| import pprint |
| import sys |
| |
| from apiclient.discovery import build |
| from oauth2client.file import Storage |
| from oauth2client.client import AccessTokenRefreshError |
| from oauth2client.client import OAuth2WebServerFlow |
| from oauth2client.tools import run |
| |
FLAGS = gflags.FLAGS

# Command-line flags. They are parsed by gflags when main() calls FLAGS(argv);
# run the program with '--help' to list every flag it understands.
gflags.DEFINE_enum(
    'logging_level',
    'ERROR',
    ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
    'Set the level of logging detail.')

gflags.DEFINE_string(
    'object_name',
    None,
    'Full Google Storage path of csv data (ex bucket/object)')

# --object_name defaults to None (no usable fallback location for the
# training data), so it is marked as required.
gflags.MarkFlagAsRequired('object_name')

# OAuth 2.0 flow used when no valid stored credentials are available.
# Despite its name, OAuth2WebServerFlow also handles the flow for native
# (installed) applications
# <http://code.google.com/apis/accounts/docs/OAuth2.html#IA>.
# The client_id and client_secret are copied from the API Access tab of the
# Google APIs Console <http://code.google.com/apis/console>. When creating
# credentials for this program, choose the "Installed application" type.
# NOTE(review): these values are embedded in the sample; anyone reusing it
# should substitute credentials for their own project.
FLOW = OAuth2WebServerFlow(
    client_id='433807057907.apps.googleusercontent.com',
    client_secret='jigtZpMApkRxncxikFpR+SFg',
    scope='https://www.googleapis.com/auth/prediction',
    user_agent='prediction-cmdline-sample/1.0')
| |
def _parse_flags(argv):
    """Parse the command line with gflags, exiting with status 1 on error.

    Args:
      argv: the full command line, including the program name in argv[0].

    Returns:
      Whatever FLAGS(argv) returns (presumably the leftover arguments).
    """
    try:
        return FLAGS(argv)
    except gflags.FlagsError as e:
        # Real newlines here: the old doubled backslash printed a literal
        # backslash-n between the error, the usage line and the flag help.
        print('%s\nUsage: %s ARGS\n%s' % (e, argv[0], FLAGS))
        sys.exit(1)


def _get_authorized_http():
    """Return an httplib2.Http object authorized with OAuth 2.0 credentials.

    Credentials are loaded from 'prediction.dat'. If none are stored, or the
    stored ones are invalid, the interactive FLOW is run with that Storage,
    which takes care of writing successful credentials back to the file.
    """
    storage = Storage('prediction.dat')
    credentials = storage.get()
    if credentials is None or credentials.invalid:
        credentials = run(FLOW, storage)
    return credentials.authorize(httplib2.Http())


def _wait_for_training(train, object_name, poll_seconds=10):
    """Poll train.get() for object_name until it succeeds; return its result.

    This sample treats any HttpError from train.get() as "training has not
    completed yet" and retries every poll_seconds seconds.
    NOTE(review): a permanent error (wrong object name, missing permission)
    is indistinguishable here and keeps the loop running; consider bounding
    the wait.

    Args:
      train: the service.training() collection object.
      object_name: the Google Storage path used as the training id.
      poll_seconds: delay between status checks.

    Returns:
      The response of the first successful train.get(...).execute().
    """
    import time
    while True:
        try:
            status = train.get(data=object_name).execute()
        except apiclient.errors.HttpError as error:
            # Keep the failure reason visible under --logging_level=DEBUG
            # instead of silently discarding it.
            logging.debug('Training status check failed: %s', error)
            print('Waiting for training to complete.')
            time.sleep(poll_seconds)
        else:
            return status


def main(argv):
    """Train a Prediction API model on --object_name and print a prediction.

    Parses the flags, applies --logging_level to the root logger, obtains
    authorized credentials, starts training on the Google Storage object named
    by --object_name, waits for training to complete, then prints the
    prediction for a sample CSV instance.

    Args:
      argv: the full command line (sys.argv), including the program name.
    """
    _parse_flags(argv)

    # Set the logging level according to the command-line flag.
    logging.getLogger().setLevel(getattr(logging, FLAGS.logging_level))

    service = build("prediction", "v1.3", http=_get_authorized_http())

    try:
        train = service.training()

        # Start training on the data set.
        start = train.insert(body={'id': FLAGS.object_name}).execute()
        print('Started training')
        pprint.pprint(start)

        status = _wait_for_training(train, FLAGS.object_name)
        pprint.pprint(status)
        print('Training is complete')

        # Now make a prediction using that training.
        body = {'input': {'csvInstance': ["mucho bueno"]}}
        prediction = train.predict(body=body, data=FLAGS.object_name).execute()
        print('The prediction is:')
        pprint.pprint(prediction)

    except AccessTokenRefreshError:
        # The trailing space matters: adjacent literals are concatenated.
        print("The credentials have been revoked or expired, please re-run "
              "the application to re-authorize")


if __name__ == '__main__':
    main(sys.argv)