blob: fab5c3b5ebd5d4e4e497b0c75f90bb0b9670845a [file] [log] [blame]
// Copyright (C) 2020 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "operator_context.h"
#include <exception>
#include "awaitable_qe_call.h"
#include "input_channel.h"
#include "output_channel.h"
#include "pyerr.h"
#include "pyerrfmt.h"
#include "pyiter.h"
#include "pylog.h"
#include "pyobj.h"
#include "pyseq.h"
#include "qe_call_terminated.h"
#include "query_execution.h"
namespace dctv {
using std::unique_ptr;
// Shut this operator down: close and drop every input and output
// channel, discard any pending request, replace the coroutine and its
// send() method with the Py_None placeholder, and -- if we are still
// linked into a QueryExecution -- ask that QueryExecution to remove us.
//
// Reached from ~OperatorContext(), py_clear(), and
// on_output_channel_disconnected() in this file.
//
// N.B. In the linked case remove_operator_on_close() may deallocate
// this object, so nothing may touch `this` after that call.
void
OperatorContext::close() noexcept
{
  // Slam our channels shut, forcing channel objects to drop their
  // references to this OperatorContext.
  // NOTE(review): these loops iterate the channel vectors while calling
  // channel->close(); this assumes close() never re-enters us and
  // mutates input_channels/output_channels -- TODO confirm.
  for (auto& channel : this->input_channels)
    channel->close();
  this->input_channels.clear();
  for (auto& channel : this->output_channels)
    channel->close();
  this->output_channels.clear();
  this->current_request.reset();
  this->reset_coroutine_state();
  // Link invariants: an ORPHANED context has neither link_qe nor
  // link_ref; any other link state has both.
  if (this->link_state == LinkState::ORPHANED) {
    assume(!this->link_qe);
    assume(!this->link_ref);
  } else {
    assume(this->link_qe);
    assume(this->link_ref);
    // N.B. remove_operator_on_close may deallocate this!
    this->link_qe->remove_operator_on_close(this);
  }
}
// Destroy an operator context.  By the time we get here the owning
// QueryExecution must already have unlinked us, so close() takes its
// ORPHANED path: it only releases channels, the pending request and
// the coroutine, and never calls remove_operator_on_close().
OperatorContext::~OperatorContext()
{
  assume(link_state == LinkState::ORPHANED);
  close();
}
// Called when one of our output channels disconnects.  A non-precious
// operator whose output channels all report is_disconnected() has
// nobody left to produce for, so it closes itself.  A precious
// operator is never closed from here.
void
OperatorContext::on_output_channel_disconnected() noexcept
{
  if (this->precious)
    return;
  // Stop scanning at the first channel that is still connected.  An
  // empty channel list counts as "all disconnected".
  for (const auto& channel : this->output_channels) {
    if (!channel->is_disconnected())
      return;
  }
  this->close();
}
// Build an OperatorContext for ACTION (which must be a query action),
// start its run_async() coroutine, drive it through setup, and only
// then hand ownership of the context to QE.
//
// Until add_operator() the new context stays ORPHANED (asserted below).
// If anything throws before then, the local reference simply drops the
// half-built context.
void
OperatorContext::install(QueryExecution* qe,
                         int ordinal,
                         unique_pyref action,
                         QueryDb* query_db)
{
  check_pytype(QueryClasses::get().cls_query_action, action);

  unique_op_ref ctx = make_pyobj<OperatorContext>();
  ctx->block_size = qe->block_size;
  ctx->ordinal = ordinal;
  ctx->action = std::move(action);
  ctx->precious = true_(getattr(ctx->action, "precious"));

  // Create the coroutine and prime it by sending None.  A coroutine
  // that fails immediately just unwinds from here without any state
  // having been changed.
  ctx->coroutine = call(getattr(ctx->action, "run_async"), qe);
  ctx->send_function = getattr(ctx->coroutine, "send");
  ctx->communicate(addref(Py_None));
  assume(ctx->link_state == LinkState::ORPHANED);

  ctx->run_setup(qe, query_db);
  assume(ctx->link_state == LinkState::ORPHANED);

  qe->add_operator(std::move(ctx));
}
// Drive the operator's pending setup request to completion.
//
// If setup throws, the exception is delivered to the operator's
// coroutine, which may respond by issuing a fresh setup request
// (calling async_setup() again).  We therefore retry until one setup
// succeeds.  We give up, by letting the exception escape, when either
// send_current_exception() or check_current_request() throws.
void
OperatorContext::run_setup(QueryExecution* qe, QueryDb* query_db)
{
  bool setup_done = false;
  while (!setup_done) {
    // Take ownership of the request; it lives until the end of this
    // iteration, including while the failure is being delivered.
    unique_ptr<QeCall> pending = std::move(this->current_request);
    try {
      pending->do_operator_setup(qe, this, query_db);
      setup_done = true;
    } catch (...) {
      this->send_current_exception();
      this->check_current_request();
    }
  }
}
// If this operator is on the run queue with a valid score, move it to
// the QueryExecution's re-score queue.  In any other link state this
// does nothing.
void
OperatorContext::invalidate_score() noexcept
{
  if (this->link_state != LinkState::ON_RUN_QUEUE_VALID_SCORE)
    return;
  assume(this->link_qe);
  this->link_qe->add_to_re_score_queue(this);
  assume(this->link_state == LinkState::ON_RUN_QUEUE_NEED_RESCORE);
}
// Mark that the current request's buffer sizes must be refreshed
// (check_current_request() applies the refresh lazily) and ask for this
// operator to be re-scored if it currently holds a valid score.
void
OperatorContext::queue_pending_block_size_change()
{
  this->need_buffer_size_update = true;
  // NOTE(review): keep this after the flag assignment in case re-scoring
  // runs compute_score() -> check_current_request(), which consumes the
  // flag.
  this->invalidate_score();
}
// Baseline score for this operator: a value-initialized Score marked
// NOT_RUNNABLE, with negated_ordinal set from this operator's ordinal.
Score
OperatorContext::compute_basic_score() const
{
  Score basic{};
  basic.negated_ordinal = -this->ordinal;
  basic.state = OperatorState::NOT_RUNNABLE;
  return basic;
}
// Run one step of this operator: execute the pending QeCall request
// and, when the owning QueryExecution provides a perf sampler, add the
// sampled cost of that step to accumulated_perf.
//
// Returns the result of the request's do_it().  If anything fails, the
// pending request is discarded and the exception is delivered to the
// operator's coroutine via send_current_exception().  If delivery
// succeeds, an empty unique_pyref is returned and the query runner loop
// calls us again.  Otherwise the exception propagates.
unique_pyref
OperatorContext::turn_crank()
{
  try {
    this->check_current_request();
    const perf::Sampler* sampler = this->link_qe->get_perf_sampler();
    // Only read when SAMPLER is non-null, in which case it is assigned
    // first.  Value-initialized anyway so it is never indeterminate.
    perf::Sample before_sample{};
    if (sampler)
      before_sample = sampler->sample();
    unique_pyref ret = this->current_request->do_it(this);
    if (sampler) {
      perf::Sample after_sample = sampler->sample();
      this->accumulated_perf += (after_sample - before_sample);
    }
    return ret;
  } catch (...) {
    this->current_request.reset();
    // send_current_exception() will eventually either re-establish
    // current_request or throw.  In the former case, we'll let the
    // query runner loop re-invoke us.
    this->send_current_exception();
    return {};
  }
}
// Return the QueryKeys named by the action's `inputs` attribute.  The
// _fast variant of py2vec is what lets `inputs` be a set object.
Vector<QueryKey>
OperatorContext::get_declared_inputs() const
{
  auto declared = getattr(this->action, "inputs");
  return py2vec_fast(declared, QueryKey::from_py);
}
// Return the QueryKeys named by the action's `outputs` attribute.  The
// _fast variant of py2vec is what lets `outputs` be a set object.
Vector<QueryKey>
OperatorContext::get_declared_outputs() const
{
  auto declared = getattr(this->action, "outputs");
  return py2vec_fast(declared, QueryKey::from_py);
}
// Drop the coroutine and its bound send() method, replacing both with
// Py_None.  Py_None doubles as the "no live coroutine" marker:
// send_current_exception() tests coroutine == Py_None to decide whether
// it can still deliver an exception.
void
OperatorContext::reset_coroutine_state() noexcept
{
  // Use a dummy value instead of just setting to nullptr so that we
  // don't need a pre-call check in the common case of
  // OperatorContext::communicate.
  this->coroutine = addref(Py_None);
  this->send_function = addref(Py_None);
}
void
OperatorContext::communicate_1(std::pair<GenRet, unique_pyref> gen_ret)
{
unique_ptr<QeCall> request;
auto& [state, ret] = gen_ret;
if (state == GenRet::RETURNED) {
this->reset_coroutine_state();
if (ret != Py_None)
throw_pyerr_fmt(PyExc_RuntimeError,
"operators must return None, not %s",
repr(ret));
Vector<IoSpec> terminal_io_specs;
terminal_io_specs.reserve(this->output_channels.size());
for (auto& output_channel : this->output_channels)
if (output_channel->needs_flush())
terminal_io_specs.emplace_back(
IoTerminalFlush(output_channel.addref()));
request = std::make_unique<QeCallTerminated>(
std::move(terminal_io_specs));
} else {
request =
std::move(ret).addref_as<AwaitableQeCall>()->extract_request();
}
request->setup(this);
this->current_request = std::move(request);
this->invalidate_score();
}
// Resume the coroutine by send()ing it REPLY_TO_PREV_QE_CALL, then act
// on whatever it yields or returns (see communicate_1()).  If the
// send() call itself throws, the coroutine is dropped before the
// exception escapes.
void
OperatorContext::communicate(unique_pyref reply_to_prev_qe_call)
{
  std::pair<GenRet, unique_pyref> gen_ret;
  try {
    gen_ret = call_gen(this->send_function, reply_to_prev_qe_call);
  } catch (...) {
    this->reset_coroutine_state();
    throw;
  }
  this->communicate_1(std::move(gen_ret));
}
// Deliver the C++ exception currently being handled to the operator's
// coroutine through the coroutine's Python throw() method, then hand
// the coroutine's response (its next yield or its return) to
// communicate_1().
//
// Must be called from inside a catch handler.  If the coroutine is
// already gone (coroutine is the Py_None placeholder), the current
// exception is simply rethrown to the caller.
//
// If delivery itself fails, the handler below calls us again with the
// new exception.  A failed call_gen() resets the coroutine to the
// placeholder (FINALLY below), so that next call rethrows rather than
// recursing further.
// NOTE(review): when a step throws while the coroutine is still live
// (e.g. communicate_1() or the getattr() below), the handler recurses
// with the coroutine intact; confirm recursion depth stays bounded for
// misbehaving operators.
void
OperatorContext::send_current_exception()
{
  if (this->coroutine == Py_None)
    throw;
  try {
    // Convert the in-flight C++ exception into a pending Python error,
    // then take it out of the error indicator so it can be passed to
    // throw() explicitly.
    assume(!pyerr_occurred());
    _set_pending_cxx_exception_as_pyexception();
    assume(pyerr_occurred());
    PyExceptionInfo info = PyExceptionInfo::fetch();
    assume(!pyerr_occurred());
    unique_pyref throw_function = getattr(this->coroutine, "throw");
    std::pair<GenRet, unique_pyref> gen_ret;
    {
      bool success = false;
      // If call_gen() throws, drop the coroutine.
      FINALLY(if (!success) this->reset_coroutine_state());
      // throw(type, value, traceback).  A missing value or traceback is
      // passed as None (GNU `?:` extension).
      // NOTE(review): this three-argument signature is deprecated for
      // generators and coroutines since Python 3.12.
      gen_ret = call_gen(throw_function,
                         info.type.notnull(),
                         info.value ?: pyref(Py_None),
                         info.traceback ?: pyref(Py_None));
      success = true;
    }
    this->communicate_1(std::move(gen_ret));
  } catch (...) {
    this->send_current_exception();
  }
}
// GC traversal hook: report each Python object this context holds,
// plus whatever the pending request reports.
//
// N.B. link_ref is deliberately NOT visited.  That field is "owned" by
// our owning QueryExecution.  If we visited it ourselves, we would
// confuse the GC into thinking we are merely part of a reference cycle,
// and we would get collected early.
int
OperatorContext::py_traverse(visitproc visit, void* arg) const noexcept
{
  const bool orphaned = this->link_state == LinkState::ORPHANED;
  if (orphaned) {
    assume(!this->link_qe);
    assume(!this->link_ref);
  } else {
    assume(this->link_qe);
    assume(this->link_ref == this);
  }

  Py_VISIT(this->action.get());
  Py_VISIT(this->coroutine.get());
  Py_VISIT(this->send_function.get());

  if (this->current_request) {
    const int rc = this->current_request->py_traverse(visit, arg);
    if (rc != 0)
      return rc;
  }

  for (const auto& in_channel : this->input_channels)
    Py_VISIT(in_channel.get());  // NOLINT
  for (const auto& out_channel : this->output_channels)
    Py_VISIT(out_channel.get());  // NOLINT
  return 0;
}
int
OperatorContext::py_clear() noexcept
{
this->close();
return 0;
}
// Score this operator.  First validate the pending request (applying
// any queued buffer-size refresh), then let that request decide the
// score.
Score
OperatorContext::compute_score() const
{
  check_current_request();
  return current_request->compute_score(this);
}
// Ensure a request is pending.  If none is, raise RuntimeError: the
// operator failed earlier.  If a buffer-size refresh has been queued,
// apply it to the pending request.  The flag is cleared only after
// refresh_resize_buffers() returns normally.
void
OperatorContext::check_current_request() const
{
  if (!this->current_request)
    throw_pyerr_fmt(PyExc_RuntimeError, "operator previously failed");
  if (!this->need_buffer_size_update)
    return;
  this->current_request->refresh_resize_buffers(
      this, &this->current_request);
  this->need_buffer_size_update = false;
}
// Replace the contents of *IO_SPECS with one IoResizeBuffer for each
// output channel that has dynamic buffer sizing enabled.
void
OperatorContext::make_buffer_resize_ops(Vector<IoSpec>* io_specs) const
{
  Vector<IoSpec>& specs = *io_specs;
  specs.clear();
  specs.reserve(this->output_channels.size());
  for (const auto& out_channel : this->output_channels) {
    if (!out_channel->is_dynamic_buffer_size_enabled())
      continue;
    specs.emplace_back(IoResizeBuffer(out_channel.addref()));
  }
}
// Pass this operator's accumulated perf counters, together with its
// action object, to the owning QueryExecution.
// NOTE(review): accumulated_perf is not reset here, so a second call
// would hand over the same totals again.  Confirm callers invoke this
// at most once per operator.  It also dereferences link_qe, which is
// null while the context is ORPHANED.
void
OperatorContext::flush_perf_to_qe()
{
  auto& qe = this->link_qe;
  qe->accumulate_perf(this->action, this->accumulated_perf);
}
// Python type object for OperatorContext.  The trailing lambda receives
// the type object under construction and makes no changes to it.
PyTypeObject OperatorContext::pytype = make_py_type<OperatorContext>(
    "dctv._native.OperatorContext",
    "Per-operator internal context for query execution",
    [](PyTypeObject* /*type*/) {});
// Register the OperatorContext Python type with module M.
void
init_operator_context(pyref m)
{
  PyTypeObject* const type = &OperatorContext::pytype;
  register_type(m, type);
}
} // namespace dctv