blob: 536d6c8f40ae2d16fb629987fe8ca9ba2743a3ee [file] [log] [blame]
Joe Gregorio11690e02011-10-14 11:05:35 -04001#!/usr/bin/python2.4
2# -*- coding: utf-8 -*-
3#
4# Copyright (C) 2010 Google Inc.
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17
"""Simple command-line sample for the Google Prediction API.

Command-line application that trains a model on your input data and then
uses it to make a prediction. This sample does the same thing as the Hello
Prediction! example. You might want to run the setup.sh script first to load
both the sample data and the PMML file into Google Storage.
24
25Usage:
26 $ python prediction_number.py --model_id="foo"
27 --data_file="data_bucket/data_object" --pmml_file="pmml_bucket/pmml_object"
28
29You can also get help on all the command-line flags the program understands
30by running:
31
32 $ python prediction_number.py --help
33
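To get detailed log output, add --logging_level=DEBUG to the required flags:

  $ python prediction_number.py --logging_level=DEBUG --model_id="foo"
    --data_file="data_bucket/data_object" --pmml_file="pmml_bucket/pmml_object"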
37"""
38
39__author__ = 'jcgregorio@google.com (Joe Gregorio)'
40
41from apiclient.discovery import build_from_document
42
43import apiclient.errors
44import gflags
45import httplib2
46import logging
47import os
48import pprint
49import sys
50
51from apiclient.discovery import build
52from oauth2client.file import Storage
53from oauth2client.client import AccessTokenRefreshError
54from oauth2client.client import flow_from_clientsecrets
55from oauth2client.tools import run
56
FLAGS = gflags.FLAGS

# CLIENT_SECRETS, name of a file containing the OAuth 2.0 information for this
# application, including client_id and client_secret, which are found
# on the API Access tab on the Google APIs
# Console <http://code.google.com/apis/console>.
# The name is relative, so it is looked up from the current working
# directory rather than from the directory that holds this script.
CLIENT_SECRETS = 'client_secrets.json'

# Message handed to flow_from_clientsecrets() below; presumably shown to the
# user when CLIENT_SECRETS is missing or invalid (confirm against the
# oauth2client docs). It reports the absolute form of the path that is
# actually read. The script's own directory would be the wrong place to
# point at whenever the sample is launched from another working directory.
MISSING_CLIENT_SECRETS_MESSAGE = """
WARNING: Please configure OAuth 2.0

To make this sample run you will need to populate the client_secrets.json file
found at:

   %s

with information from the APIs Console <https://code.google.com/apis/console>.

""" % os.path.abspath(CLIENT_SECRETS)

# Set up a Flow object to be used if we need to authenticate. This runs at
# import time, so the client secrets are loaded as soon as the module loads.
FLOW = flow_from_clientsecrets(
    CLIENT_SECRETS,
    scope='https://www.googleapis.com/auth/prediction',
    message=MISSING_CLIENT_SECRETS_MESSAGE)
83
# Command-line flags are registered with gflags at import time; run this
# program with '--help' to list every flag it understands.
gflags.DEFINE_enum(
    'logging_level',
    'ERROR',
    ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
    'Set the level of logging detail.')

# The three sample inputs default to None and are all marked required, so
# flag parsing rejects a command line that omits any of them.
_REQUIRED_STRING_FLAGS = (
    ('model_id',
     'The unique name for the predictive model (ex foo)'),
    ('data_file',
     'Full Google Storage path of csv data (ex bucket/object)'),
    ('pmml_file',
     'Full Google Storage path of pmml for '
     'preprocessing (ex bucket/object)'),
)

for _flag_name, _flag_help in _REQUIRED_STRING_FLAGS:
    gflags.DEFINE_string(_flag_name, None, _flag_help)

# Mark the flags only after all of them are defined, in the same order.
for _flag_name, _flag_help in _REQUIRED_STRING_FLAGS:
    gflags.MarkFlagAsRequired(_flag_name)

# Keep the module namespace free of the registration scaffolding.
del _flag_name, _flag_help, _REQUIRED_STRING_FLAGS
107
def _wait_for_training(train, model_id, poll_seconds=10):
    """Poll a trained-model resource until its training job stops running.

    Args:
      train: the collection returned by service.trainedmodels().
      model_id: id of the model whose training job is being watched.
      poll_seconds: number of seconds to sleep between status checks.

    Returns:
      The last status object returned by train.get(...).execute().
    """
    import time

    while True:
        try:
            status = train.get(id=model_id).execute()
        except apiclient.errors.HttpError as error:
            # The original sample treats an HTTP error from get() as "the
            # training job has not completed yet", so keep waiting.
            logging.debug('Training status check failed: %s', error)
        else:
            # NOTE(review): assumes the response may carry a 'trainingStatus'
            # field reading 'RUNNING' while the job is in progress; confirm
            # against the Prediction API v1.4 reference. Responses without
            # the field are treated as finished, as the original loop did.
            if status.get('trainingStatus') != 'RUNNING':
                return status
        print('Waiting for training to complete.')
        time.sleep(poll_seconds)


def main(argv):
    """Train a Prediction API model from command-line flags, then predict.

    Args:
      argv: the full command line; argv[0] is used as the program name in
        the usage message printed when flag parsing fails.

    Exits with status 1 on a flag-parsing error, when training does not end
    in a DONE state, or when the stored credentials can no longer be
    refreshed.
    """
    # Let the gflags module process the command-line arguments.
    try:
        argv = FLAGS(argv)
    except gflags.FlagsError as e:
        # Use real newlines; an escaped '\\n' would print a literal backslash-n.
        print('%s\nUsage: %s ARGS\n%s' % (e, argv[0], FLAGS))
        sys.exit(1)

    # Set the logging according to the command-line flag.
    logging.getLogger().setLevel(getattr(logging, FLAGS.logging_level))

    # If the Credentials don't exist or are invalid, run through the native
    # client flow. The Storage object will ensure that if successful the good
    # Credentials will get written back to a file.
    storage = Storage('prediction.dat')
    credentials = storage.get()
    if credentials is None or credentials.invalid:
        credentials = run(FLOW, storage)

    # Create an httplib2.Http object to handle our HTTP requests and authorize
    # it with our good Credentials.
    http = credentials.authorize(httplib2.Http())

    service = build("prediction", "v1.4", http=http)

    try:
        # Start training on a data set.
        train = service.trainedmodels()
        body = {'id': FLAGS.model_id, 'storageDataLocation': FLAGS.data_file,
                'storagePMMLLocation': FLAGS.pmml_file}
        start = train.insert(body=body).execute()

        print('Started training')
        pprint.pprint(start)

        # Wait for the training to complete.
        status = _wait_for_training(train, FLAGS.model_id)
        pprint.pprint(status)

        # A missing trainingStatus keeps the original "finished" meaning; any
        # value other than DONE is reported instead of predicting against a
        # model that failed to train.
        state = status.get('trainingStatus', 'DONE')
        if state != 'DONE':
            print('Training did not complete successfully: %s' % state)
            sys.exit(1)
        print('Training is complete')

        # Now make a prediction using that training.
        body = {'input': {'csvInstance': [5]}}
        prediction = train.predict(body=body, id=FLAGS.model_id).execute()
        print('The prediction is:')
        pprint.pprint(prediction)

    except AccessTokenRefreshError:
        print('The credentials have been revoked or expired, please re-run '
              'the application to re-authorize')
        sys.exit(1)
172
if __name__ == '__main__':
    # Hand main() the complete command line; it reports argv[0] as the
    # program name when flag parsing fails.
    main(argv=sys.argv)
175